diff --git a/.gitignore b/.gitignore index 3014a3b..68359b1 100644 --- a/.gitignore +++ b/.gitignore @@ -13,4 +13,8 @@ settings.json .DS_Store -.vscode/alive \ No newline at end of file +.vscode/alive + +# Added by cargo + +/target diff --git a/.vscode/launch.json b/.vscode/launch.json index 3b9c326..a6c15e4 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -10,7 +10,55 @@ "name": "Debug", "program": "${workspaceFolder}/gemla/target/debug/gemla.exe", "args": ["./gemla/temp/"], - "cwd": "${workspaceFolder}" + "cwd": "${workspaceFolder}/gemla" + }, + { + "type": "lldb", + "request": "launch", + "name": "Debug Rust Tests", + "cargo": { + "args": [ + "test", + "--manifest-path", "${workspaceFolder}/gemla/Cargo.toml", + "--no-run", // Compiles the tests without running them + "--package=gemla", // Specify your package name if necessary + "--bin=bin" + ], + "filter": { } + }, + "args": [], + }, + { + "type": "lldb", + "request": "launch", + "name": "Debug gemla Lib Tests", + "cargo": { + "args": [ + "test", + "--manifest-path", "${workspaceFolder}/gemla/Cargo.toml", + "--no-run", // Compiles the tests without running them + "--package=gemla", // Specify your package name if necessary + "--lib" + ], + "filter": { } + }, + "args": [], + }, + { + "type": "lldb", + "request": "launch", + "name": "Debug Rust FileLinked Tests", + "cargo": { + "args": [ + "test", + "--manifest-path", "${workspaceFolder}/file_linked/Cargo.toml", + "--no-run", // Compiles the tests without running them + "--package=file_linked", // Specify your package name if necessary + "--lib" + ], + "filter": { } + }, + "args": [], } ] } \ No newline at end of file diff --git a/analyze_data.py b/analyze_data.py new file mode 100644 index 0000000..2eba59d --- /dev/null +++ b/analyze_data.py @@ -0,0 +1,171 @@ +# Re-importing necessary libraries +import json +import matplotlib.pyplot as plt +from collections import defaultdict +import numpy as np + +# Simplified JSON data for demonstration +with open('gemla/round4.json', 'r') as file: + simplified_json_data = json.load(file) + +target_node_id = '523f8250-3101-4586-90a1-127ffa6d73d9' + +# Function to traverse the tree to find a node id +def traverse_left_nodes(node): + if node is None: + return [] + + left_node = node.get("left") + if left_node is None: + return [node] + + return [node] + traverse_left_nodes(left_node) + +# Function to traverse the tree to find a node id +def traverse_right_nodes(node): + if node is None: + return [] + + right_node = node.get("right") + left_node = node.get("left") + + if right_node is None and left_node is None: + return [] + elif right_node and left_node: + return [right_node] + traverse_right_nodes(left_node) + + return [] + + +# Getting the left graph +left_nodes = traverse_left_nodes(simplified_json_data[0]) +left_nodes.reverse() +# print(node) +# Print properties available on the first node +node = left_nodes[0] +# print(node["val"].keys()) + +scores = [] +for node in left_nodes: + # print(node) + # print(f'Node ID: {node["val"]["id"]}') + # print(f'Node scores length: {len(node["val"]["node"]["scores"])}') + if node["val"]["node"]: + node_scores = node["val"]["node"]["scores"] + if node_scores: + for score in node_scores: + scores.append(score) + +# print(scores) + +scores_values = [list(score_set.values()) for score_set in scores] + +# Set up the figure for plotting on the same graph +fig, ax = plt.subplots(figsize=(10, 6)) + +# Generate a boxplot for each set of scores on the same graph +boxplots = ax.boxplot(scores_values, 
vert=False, patch_artist=True, labels=[f'Set {i+1}' for i in range(len(scores_values))])
+
+# Use a symmetric log scale for the score axis
+ax.set_xscale('symlog', linthresh=1.0)
+
+# Labeling
+ax.set_xlabel('Scores - Main Line')
+ax.set_ylabel('Score Sets')
+ax.yaxis.grid(True)  # Add horizontal grid lines for clarity
+
+# Set y-axis labels to be visible
+ax.set_yticklabels([f'Set {i+1}' for i in range(len(scores_values))])
+
+# Getting most recent right graph
+right_nodes = traverse_right_nodes(simplified_json_data[0])
+if len(right_nodes) != 0:
+    target_node_id = None
+    target_node = None
+    if target_node_id:
+        for node in right_nodes:
+            if node["val"]["id"] == target_node_id:
+                target_node = node
+                break
+    else:
+        target_node = right_nodes[0]
+    scores = target_node["val"]["node"]["scores"]
+
+    scores_values = [list(score_set.values()) for score_set in scores]
+
+    # Set up the figure for plotting on the same graph
+    fig, ax = plt.subplots(figsize=(10, 6))
+
+    # Generate a boxplot for each set of scores on the same graph
+    boxplots = ax.boxplot(scores_values, vert=False, patch_artist=True, labels=[f'Set {i+1}' for i in range(len(scores_values))])
+
+    ax.set_xscale('symlog', linthresh=1.0)
+
+    # Labeling
+    ax.set_xlabel(f"Scores: {target_node['val']['id']}")
+    ax.set_ylabel('Score Sets')
+    ax.yaxis.grid(True)  # Add horizontal grid lines for clarity
+
+    # Set y-axis labels to be visible
+    ax.set_yticklabels([f'Set {i+1}' for i in range(len(scores_values))])
+
+# Find the highest scoring sets combining all scores and generations
+scores = []
+for node in left_nodes:
+    if node["val"]["node"]:
+        node_scores = node["val"]["node"]["scores"]
+        translated_node_scores = []
+        if node_scores:
+            for i in range(len(node_scores)):
+                for (individual, score) in node_scores[i].items():
+                    translated_node_scores.append((node["val"]["id"], i, score))
+
+        scores.append(translated_node_scores)
+
+# Add scores from the right nodes
+if len(right_nodes) != 0:
+    for node in right_nodes:
+        if node["val"]["node"]:
+            node_scores = node["val"]["node"]["scores"]
+            translated_node_scores = []
+            if node_scores:
+                for i in range(len(node_scores)):
+                    for (individual, score) in node_scores[i].items():
+                        translated_node_scores.append((node["val"]["id"], i, score))
+            scores.append(translated_node_scores)
+
+# Organize scores by individual and then by generation
+individual_generation_scores = defaultdict(lambda: defaultdict(list))
+for sublist in scores:
+    for id, generation, score in sublist:
+        individual_generation_scores[id][generation].append(score)
+
+# Calculate Q3 for each individual's generation
+individual_generation_q3 = {}
+for id, generations in individual_generation_scores.items():
+    for gen, scores in generations.items():
+        individual_generation_q3[(id, gen)] = np.percentile(scores, 75)
+
+# Sort by Q3 value, highest first, and select the top 20
+top_20_individual_generations = sorted(individual_generation_q3, key=individual_generation_q3.get, reverse=True)[:20]
+
+# Prepare scores for the top 20 for plotting
+top_20_scores = [individual_generation_scores[id][gen] for id, gen in top_20_individual_generations]
+
+# Adjust labels for clarity, indicating both the individual ID and generation
+labels = [f'{id[:8]}... 
Gen {gen}' for id, gen in top_20_individual_generations] + +# Generate box and whisker plots for the top 20 individual generations +fig, ax = plt.subplots(figsize=(12, 10)) +ax.boxplot(top_20_scores, vert=False, patch_artist=True, labels=labels) + +ax.set_xscale('symlog', linthresh=1.0) + +ax.set_xlabel('Scores') +ax.set_ylabel('Individual Generation') +ax.set_title('Top 20 Individual Generations by Q3 Value') +ax.yaxis.grid(True) # Add horizontal grid lines for clarity + +# Display the plot +plt.show() + diff --git a/carp_spike/.gitignore b/carp_spike/.gitignore deleted file mode 100644 index 466e248..0000000 --- a/carp_spike/.gitignore +++ /dev/null @@ -1 +0,0 @@ -out/ \ No newline at end of file diff --git a/carp_spike/main.carp b/carp_spike/main.carp deleted file mode 100644 index 344241e..0000000 --- a/carp_spike/main.carp +++ /dev/null @@ -1,10 +0,0 @@ -(use Random) -(Project.config "title" "gemla") - -(deftype SimulationNode [population-size Int, population-cutoff Int]) - -;; (let [test (SimulationNode.init 10 3)] -;; (do -;; (SimulationNode.set-population-size test 20) -;; (SimulationNode.population-size &test) -;; )) \ No newline at end of file diff --git a/extract_fann_data/Cargo.toml b/extract_fann_data/Cargo.toml new file mode 100644 index 0000000..c324aaf --- /dev/null +++ b/extract_fann_data/Cargo.toml @@ -0,0 +1,9 @@ +[package] +name = "extract_fann_data" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +fann = "0.1.8" diff --git a/extract_fann_data/build.rs b/extract_fann_data/build.rs new file mode 100644 index 0000000..e6b8ca6 --- /dev/null +++ b/extract_fann_data/build.rs @@ -0,0 +1,11 @@ +fn main() { + // Replace this with the path to the directory containing `fann.lib` + let lib_dir = "F://vandomej/Downloads/vcpkg/packages/fann_x64-windows/lib"; + + println!("cargo:rustc-link-search=native={}", lib_dir); + println!("cargo:rustc-link-lib=static=fann"); + // Use `dylib=fann` instead of `static=fann` if you're linking dynamically + + // If there are any additional directories where the compiler can find header files, you can specify them like this: + // println!("cargo:include={}", path_to_include_directory); +} diff --git a/extract_fann_data/src/main.rs b/extract_fann_data/src/main.rs new file mode 100644 index 0000000..8d6f03b --- /dev/null +++ b/extract_fann_data/src/main.rs @@ -0,0 +1,38 @@ +extern crate fann; + +use fann::Fann; +use std::os::raw::c_uint; + +fn main() { + let args: Vec = std::env::args().collect(); + if args.len() < 2 { + eprintln!("Usage: {} ", args[0]); + std::process::exit(1); + } + + let network_file = &args[1]; + match Fann::from_file(network_file) { + Ok(ann) => { + // Output layer sizes + let layer_sizes = ann.get_layer_sizes(); + let bias_counts = ann.get_bias_counts(); + + println!("Layers:"); + for (layer_size, bias_count) in layer_sizes.iter().zip(bias_counts.iter()) { + println!("{} {}", layer_size, bias_count); + } + + // Output connections + println!("Connections:"); + let connections = ann.get_connections(); + + for connection in connections { + println!("{} {} {}", connection.from_neuron, connection.to_neuron, connection.weight); + } + }, + Err(err) => { + eprintln!("Error loading network from file {}: {}", network_file, err); + std::process::exit(1); + } + } +} diff --git a/file_linked/Cargo.toml b/file_linked/Cargo.toml index abf3367..14ba835 100644 --- a/file_linked/Cargo.toml +++ b/file_linked/Cargo.toml @@ -19,4 +19,7 @@ serde 
= { version = "1.0", features = ["derive"] } thiserror = "1.0" anyhow = "1.0" bincode = "1.3.3" -log = "0.4.14" \ No newline at end of file +log = "0.4.14" +serde_json = "1.0.114" +tokio = { version = "1.37.0", features = ["full"] } +futures = "0.3.30" diff --git a/file_linked/src/constants/data_format.rs b/file_linked/src/constants/data_format.rs new file mode 100644 index 0000000..9a9940e --- /dev/null +++ b/file_linked/src/constants/data_format.rs @@ -0,0 +1,5 @@ +#[derive(Debug)] +pub enum DataFormat { + Bincode, + Json, +} diff --git a/file_linked/src/constants/mod.rs b/file_linked/src/constants/mod.rs new file mode 100644 index 0000000..d17427a --- /dev/null +++ b/file_linked/src/constants/mod.rs @@ -0,0 +1 @@ +pub mod data_format; diff --git a/file_linked/src/lib.rs b/file_linked/src/lib.rs index 5b73c1d..694b114 100644 --- a/file_linked/src/lib.rs +++ b/file_linked/src/lib.rs @@ -1,8 +1,10 @@ //! A wrapper around an object that ties it to a physical file +pub mod constants; pub mod error; use anyhow::{anyhow, Context}; +use constants::data_format::DataFormat; use error::Error; use log::info; use serde::{de::DeserializeOwned, Serialize}; @@ -10,9 +12,10 @@ use std::{ fs::{copy, remove_file, File}, io::{ErrorKind, Write}, path::{Path, PathBuf}, - thread, - thread::JoinHandle, + sync::Arc, + thread::{self, JoinHandle}, }; +use tokio::sync::RwLock; /// A wrapper around an object `T` that ties the object to a physical file #[derive(Debug)] @@ -20,10 +23,11 @@ pub struct FileLinked where T: Serialize, { - val: T, + val: Arc>, path: PathBuf, temp_file_path: PathBuf, file_thread: Option>, + data_format: DataFormat, } impl Drop for FileLinked @@ -48,10 +52,12 @@ where /// # Examples /// ``` /// # use file_linked::*; + /// # use file_linked::constants::data_format::DataFormat; /// # use serde::{Deserialize, Serialize}; /// # use std::fmt; /// # use std::string::ToString; /// # use std::path::PathBuf; + /// # use tokio; /// # /// # #[derive(Deserialize, Serialize)] /// # struct Test { @@ -60,27 +66,30 @@ where /// # pub c: f64 /// # } /// # - /// # fn main() { + /// # #[tokio::main] + /// # async fn main() { /// let test = Test { /// a: 1, /// b: String::from("two"), /// c: 3.0 /// }; /// - /// let linked_test = FileLinked::new(test, &PathBuf::from("./temp")) + /// let linked_test = FileLinked::new(test, &PathBuf::from("./temp"), DataFormat::Json).await /// .expect("Unable to create file linked object"); /// - /// assert_eq!(linked_test.readonly().a, 1); - /// assert_eq!(linked_test.readonly().b, String::from("two")); - /// assert_eq!(linked_test.readonly().c, 3.0); + /// let readonly = linked_test.readonly(); + /// let readonly_ref = readonly.read().await; + /// assert_eq!(readonly_ref.a, 1); + /// assert_eq!(readonly_ref.b, String::from("two")); + /// assert_eq!(readonly_ref.c, 3.0); /// # /// # drop(linked_test); /// # /// # std::fs::remove_file("./temp").expect("Unable to remove file"); /// # } /// ``` - pub fn readonly(&self) -> &T { - &self.val + pub fn readonly(&self) -> Arc> { + self.val.clone() } /// Creates a new [`FileLinked`] object of type `T` stored to the file given by `path`. 
@@ -88,10 +97,12 @@ where /// # Examples /// ``` /// # use file_linked::*; + /// # use file_linked::constants::data_format::DataFormat; /// # use serde::{Deserialize, Serialize}; /// # use std::fmt; /// # use std::string::ToString; /// # use std::path::PathBuf; + /// # use tokio; /// # /// #[derive(Deserialize, Serialize)] /// struct Test { @@ -100,26 +111,29 @@ where /// pub c: f64 /// } /// - /// # fn main() { + /// #[tokio::main] + /// # async fn main() { /// let test = Test { /// a: 1, /// b: String::from("two"), /// c: 3.0 /// }; /// - /// let linked_test = FileLinked::new(test, &PathBuf::from("./temp")) + /// let linked_test = FileLinked::new(test, &PathBuf::from("./temp"), DataFormat::Json).await /// .expect("Unable to create file linked object"); /// - /// assert_eq!(linked_test.readonly().a, 1); - /// assert_eq!(linked_test.readonly().b, String::from("two")); - /// assert_eq!(linked_test.readonly().c, 3.0); + /// let readonly = linked_test.readonly(); + /// let readonly_ref = readonly.read().await; + /// assert_eq!(readonly_ref.a, 1); + /// assert_eq!(readonly_ref.b, String::from("two")); + /// assert_eq!(readonly_ref.c, 3.0); /// # /// # drop(linked_test); /// # /// # std::fs::remove_file("./temp").expect("Unable to remove file"); /// # } /// ``` - pub fn new(val: T, path: &Path) -> Result, Error> { + pub async fn new(val: T, path: &Path, data_format: DataFormat) -> Result, Error> { let mut temp_file_path = path.to_path_buf(); temp_file_path.set_file_name(format!( ".temp{}", @@ -130,21 +144,28 @@ where )); let mut result = FileLinked { - val, + val: Arc::new(RwLock::new(val)), path: path.to_path_buf(), temp_file_path, file_thread: None, + data_format, }; - result.write_data()?; + result.write_data().await?; Ok(result) } - fn write_data(&mut self) -> Result<(), Error> { + async fn write_data(&mut self) -> Result<(), Error> { let thread_path = self.path.clone(); let thread_temp_path = self.temp_file_path.clone(); - let thread_val = bincode::serialize(&self.val) - .with_context(|| "Unable to serialize object into bincode".to_string())?; + let val = self.val.read().await; + + let thread_val = match self.data_format { + DataFormat::Bincode => bincode::serialize(&*val) + .with_context(|| "Unable to serialize object into bincode".to_string())?, + DataFormat::Json => serde_json::to_vec(&*val) + .with_context(|| "Unable to serialize object into JSON".to_string())?, + }; if let Some(file_thread) = self.file_thread.take() { file_thread @@ -190,10 +211,12 @@ where /// ``` /// # use file_linked::*; /// # use file_linked::error::Error; + /// # use file_linked::constants::data_format::DataFormat; /// # use serde::{Deserialize, Serialize}; /// # use std::fmt; /// # use std::string::ToString; /// # use std::path::PathBuf; + /// # use tokio; /// # /// # #[derive(Deserialize, Serialize)] /// # struct Test { @@ -202,21 +225,28 @@ where /// # pub c: f64 /// # } /// # - /// # fn main() -> Result<(), Error> { + /// # #[tokio::main] + /// # async fn main() -> Result<(), Error> { /// let test = Test { /// a: 1, /// b: String::from(""), /// c: 0.0 /// }; /// - /// let mut linked_test = FileLinked::new(test, &PathBuf::from("./temp")) + /// let mut linked_test = FileLinked::new(test, &PathBuf::from("./temp"), DataFormat::Bincode).await /// .expect("Unable to create file linked object"); /// - /// assert_eq!(linked_test.readonly().a, 1); + /// { + /// let readonly = linked_test.readonly(); + /// let readonly_ref = readonly.read().await; + /// assert_eq!(readonly_ref.a, 1); + /// } /// - /// 
linked_test.mutate(|t| t.a = 2)?; + /// linked_test.mutate(|t| t.a = 2).await?; /// - /// assert_eq!(linked_test.readonly().a, 2); + /// let readonly = linked_test.readonly(); + /// let readonly_ref = readonly.read().await; + /// assert_eq!(readonly_ref.a, 2); /// # /// # drop(linked_test); /// # @@ -225,10 +255,15 @@ where /// # Ok(()) /// # } /// ``` - pub fn mutate U>(&mut self, op: F) -> Result { - let result = op(&mut self.val); + pub async fn mutate U>(&mut self, op: F) -> Result { + let val_clone = self.val.clone(); // Arc> + let mut val = val_clone.write().await; // RwLockWriteGuard - self.write_data()?; + let result = op(&mut val); + + drop(val); + + self.write_data().await?; Ok(result) } @@ -239,10 +274,12 @@ where /// ``` /// # use file_linked::*; /// # use file_linked::error::Error; + /// # use file_linked::constants::data_format::DataFormat; /// # use serde::{Deserialize, Serialize}; /// # use std::fmt; /// # use std::string::ToString; /// # use std::path::PathBuf; + /// # use tokio; /// # /// # #[derive(Deserialize, Serialize)] /// # struct Test { @@ -251,25 +288,30 @@ where /// # pub c: f64 /// # } /// # - /// # fn main() -> Result<(), Error> { + /// # #[tokio::main] + /// # async fn main() -> Result<(), Error> { /// let test = Test { /// a: 1, /// b: String::from(""), /// c: 0.0 /// }; /// - /// let mut linked_test = FileLinked::new(test, &PathBuf::from("./temp")) + /// let mut linked_test = FileLinked::new(test, &PathBuf::from("./temp"), DataFormat::Bincode).await /// .expect("Unable to create file linked object"); /// - /// assert_eq!(linked_test.readonly().a, 1); + /// let readonly = linked_test.readonly(); + /// let readonly_ref = readonly.read().await; + /// assert_eq!(readonly_ref.a, 1); /// /// linked_test.replace(Test { /// a: 2, /// b: String::from(""), /// c: 0.0 - /// })?; + /// }).await?; /// - /// assert_eq!(linked_test.readonly().a, 2); + /// let readonly = linked_test.readonly(); + /// let readonly_ref = readonly.read().await; + /// assert_eq!(readonly_ref.a, 2); /// # /// # drop(linked_test); /// # @@ -278,10 +320,30 @@ where /// # Ok(()) /// # } /// ``` - pub fn replace(&mut self, val: T) -> Result<(), Error> { - self.val = val; + pub async fn replace(&mut self, val: T) -> Result<(), Error> { + self.val = Arc::new(RwLock::new(val)); - self.write_data() + self.write_data().await + } +} + +impl FileLinked +where + T: Serialize + DeserializeOwned + Send + 'static, +{ + /// Asynchronously modifies the data contained in a `FileLinked` object using an async callback `op`. 
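+ ///
+ /// # Examples
+ ///
+ /// A minimal illustrative sketch (modeled on the `mutate` example above and the async test
+ /// in this module); the temp file path is arbitrary:
+ ///
+ /// ```
+ /// # use file_linked::*;
+ /// # use file_linked::error::Error;
+ /// # use file_linked::constants::data_format::DataFormat;
+ /// # use serde::{Deserialize, Serialize};
+ /// # use std::path::PathBuf;
+ /// # use tokio;
+ /// #
+ /// # #[derive(Deserialize, Serialize)]
+ /// # struct Test {
+ /// #     pub a: u32,
+ /// #     pub b: String,
+ /// #     pub c: f64
+ /// # }
+ /// #
+ /// # #[tokio::main]
+ /// # async fn main() -> Result<(), Error> {
+ /// let test = Test {
+ ///     a: 1,
+ ///     b: String::from(""),
+ ///     c: 0.0
+ /// };
+ ///
+ /// let mut linked_test = FileLinked::new(test, &PathBuf::from("./temp_async"), DataFormat::Json).await
+ ///     .expect("Unable to create file linked object");
+ ///
+ /// // The async closure receives the shared Arc<RwLock<T>> and mutates through a write guard
+ /// linked_test.mutate_async(|val| async move {
+ ///     let mut val = val.write().await;
+ ///     val.a = 2;
+ /// }).await?;
+ ///
+ /// let readonly = linked_test.readonly();
+ /// let readonly_ref = readonly.read().await;
+ /// assert_eq!(readonly_ref.a, 2);
+ /// #
+ /// # drop(linked_test);
+ /// #
+ /// # std::fs::remove_file("./temp_async").expect("Unable to remove file");
+ /// # Ok(())
+ /// # }
+ /// ```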
+ pub async fn mutate_async(&mut self, op: F) -> Result + where + F: FnOnce(Arc>) -> Fut, + Fut: std::future::Future + Send, + U: Send, + { + let val_clone = self.val.clone(); + let result = op(val_clone).await; + + self.write_data().await?; + + Ok(result) } } @@ -295,6 +357,7 @@ where /// ``` /// # use file_linked::*; /// # use file_linked::error::Error; + /// # use file_linked::constants::data_format::DataFormat; /// # use serde::{Deserialize, Serialize}; /// # use std::fmt; /// # use std::string::ToString; @@ -302,6 +365,7 @@ where /// # use std::fs::OpenOptions; /// # use std::io::Write; /// # use std::path::PathBuf; + /// # use tokio; /// # /// # #[derive(Deserialize, Serialize)] /// # struct Test { @@ -310,7 +374,8 @@ where /// # pub c: f64 /// # } /// # - /// # fn main() -> Result<(), Error> { + /// # #[tokio::main] + /// # async fn main() -> Result<(), Error> { /// let test = Test { /// a: 1, /// b: String::from("2"), @@ -327,12 +392,14 @@ where /// /// bincode::serialize_into(file, &test).expect("Unable to serialize object"); /// - /// let mut linked_test = FileLinked::::from_file(&path) + /// let mut linked_test = FileLinked::::from_file(&path, DataFormat::Bincode) /// .expect("Unable to create file linked object"); /// - /// assert_eq!(linked_test.readonly().a, test.a); - /// assert_eq!(linked_test.readonly().b, test.b); - /// assert_eq!(linked_test.readonly().c, test.c); + /// let readonly = linked_test.readonly(); + /// let readonly_ref = readonly.read().await; + /// assert_eq!(readonly_ref.a, test.a); + /// assert_eq!(readonly_ref.b, test.b); + /// assert_eq!(readonly_ref.c, test.c); /// # /// # drop(linked_test); /// # @@ -341,7 +408,7 @@ where /// # Ok(()) /// # } /// ``` - pub fn from_file(path: &Path) -> Result, Error> { + pub fn from_file(path: &Path, data_format: DataFormat) -> Result, Error> { let mut temp_file_path = path.to_path_buf(); temp_file_path.set_file_name(format!( ".temp{}", @@ -351,16 +418,22 @@ where .ok_or_else(|| anyhow!("Unable to get filename for tempfile {}", path.display()))? 
)); - match File::open(path).map_err(Error::from).and_then(|file| { - bincode::deserialize_from::(file) - .with_context(|| format!("Unable to deserialize file {}", path.display())) - .map_err(Error::from) - }) { + match File::open(path) + .map_err(Error::from) + .and_then(|file| match data_format { + DataFormat::Bincode => bincode::deserialize_from::(file) + .with_context(|| format!("Unable to deserialize file {}", path.display())) + .map_err(Error::from), + DataFormat::Json => serde_json::from_reader(file) + .with_context(|| format!("Unable to deserialize file {}", path.display())) + .map_err(Error::from), + }) { Ok(val) => Ok(FileLinked { - val, + val: Arc::new(RwLock::new(val)), path: path.to_path_buf(), temp_file_path, file_thread: None, + data_format, }), Err(err) => { info!( @@ -370,30 +443,43 @@ where ); // Try to use temp file instead and see if that file exists and is serializable - let val = FileLinked::from_temp_file(&temp_file_path, path) + let val = FileLinked::from_temp_file(&temp_file_path, path, &data_format) .map_err(|_| err) .with_context(|| format!("Failed to read/deserialize the object from the file {} and temp file {}", path.display(), temp_file_path.display()))?; Ok(FileLinked { - val, + val: Arc::new(RwLock::new(val)), path: path.to_path_buf(), temp_file_path, file_thread: None, + data_format, }) } } } - fn from_temp_file(temp_file_path: &Path, path: &Path) -> Result { + fn from_temp_file( + temp_file_path: &Path, + path: &Path, + data_format: &DataFormat, + ) -> Result { let file = File::open(temp_file_path) .with_context(|| format!("Unable to open file {}", temp_file_path.display()))?; - let val = bincode::deserialize_from(file).with_context(|| { - format!( - "Could not deserialize from temp file {}", - temp_file_path.display() - ) - })?; + let val = match data_format { + DataFormat::Bincode => bincode::deserialize_from(file).with_context(|| { + format!( + "Could not deserialize from temp file {}", + temp_file_path.display() + ) + })?, + DataFormat::Json => serde_json::from_reader(file).with_context(|| { + format!( + "Could not deserialize from temp file {}", + temp_file_path.display() + ) + })?, + }; info!("Successfully deserialized value from temp file"); @@ -421,8 +507,12 @@ mod tests { } } - pub fn run Result<(), Error>>(&self, op: F) -> Result<(), Error> { - op(&self.path) + pub async fn run(&self, op: F) -> () + where + F: FnOnce(PathBuf) -> Fut, + Fut: std::future::Future, + { + op(self.path.clone()).await } } @@ -434,92 +524,173 @@ mod tests { } } - #[test] - fn test_readonly() -> Result<(), Error> { + #[tokio::test] + async fn test_readonly() { let path = PathBuf::from("test_readonly"); let cleanup = CleanUp::new(&path); - cleanup.run(|p| { - let val = vec!["one", "two", ""]; + cleanup + .run(|p| async move { + let val = vec!["one", "two", ""]; - let linked_object = FileLinked::new(val.clone(), &p)?; - assert_eq!(*linked_object.readonly(), val); - - Ok(()) - }) + let linked_object = FileLinked::new(val.clone(), &p, DataFormat::Json) + .await + .expect("Unable to create file linked object"); + let linked_object_arc = linked_object.readonly(); + let linked_object_ref = linked_object_arc.read().await; + assert_eq!(*linked_object_ref, val); + }) + .await; } - #[test] - fn test_new() -> Result<(), Error> { + #[tokio::test] + async fn test_new() { let path = PathBuf::from("test_new"); let cleanup = CleanUp::new(&path); - cleanup.run(|p| { - let val = "test"; + cleanup + .run(|p| async move { + let val = "test"; - FileLinked::new(val, &p)?; + FileLinked::new(val, 
&p, DataFormat::Bincode) + .await + .expect("Unable to create file linked object"); - let file = File::open(&p)?; - let result: String = - bincode::deserialize_from(file).expect("Unable to deserialize from file"); - assert_eq!(result, val); - - Ok(()) - }) + let file = File::open(&p).expect("Unable to open file"); + let result: String = + bincode::deserialize_from(file).expect("Unable to deserialize from file"); + assert_eq!(result, val); + }) + .await; } - #[test] - fn test_mutate() -> Result<(), Error> { + #[tokio::test] + async fn test_mutate() { let path = PathBuf::from("test_mutate"); let cleanup = CleanUp::new(&path); - cleanup.run(|p| { - let list = vec![1, 2, 3, 4]; - let mut file_linked_list = FileLinked::new(list, &p)?; - assert_eq!(*file_linked_list.readonly(), vec![1, 2, 3, 4]); + cleanup + .run(|p| async move { + let list = vec![1, 2, 3, 4]; + let mut file_linked_list = FileLinked::new(list, &p, DataFormat::Json) + .await + .expect("Unable to create file linked object"); + let file_linked_list_arc = file_linked_list.readonly(); + let file_linked_list_ref = file_linked_list_arc.read().await; - file_linked_list.mutate(|v1| v1.push(5))?; - assert_eq!(*file_linked_list.readonly(), vec![1, 2, 3, 4, 5]); + assert_eq!(*file_linked_list_ref, vec![1, 2, 3, 4]); - file_linked_list.mutate(|v1| v1[1] = 1)?; - assert_eq!(*file_linked_list.readonly(), vec![1, 1, 3, 4, 5]); + drop(file_linked_list_ref); + file_linked_list + .mutate(|v1| v1.push(5)) + .await + .expect("Error mutating file linked object"); + let file_linked_list_arc = file_linked_list.readonly(); + let file_linked_list_ref = file_linked_list_arc.read().await; - drop(file_linked_list); - Ok(()) - }) + assert_eq!(*file_linked_list_ref, vec![1, 2, 3, 4, 5]); + + drop(file_linked_list_ref); + file_linked_list + .mutate(|v1| v1[1] = 1) + .await + .expect("Error mutating file linked object"); + let file_linked_list_arc = file_linked_list.readonly(); + let file_linked_list_ref = file_linked_list_arc.read().await; + + assert_eq!(*file_linked_list_ref, vec![1, 1, 3, 4, 5]); + + drop(file_linked_list); + }) + .await; } - #[test] - fn test_replace() -> Result<(), Error> { + #[tokio::test] + async fn test_async_mutate() { + let path = PathBuf::from("test_async_mutate"); + let cleanup = CleanUp::new(&path); + cleanup + .run(|p| async move { + let list = vec![1, 2, 3, 4]; + let mut file_linked_list = FileLinked::new(list, &p, DataFormat::Json) + .await + .expect("Unable to create file linked object"); + let file_linked_list_arc = file_linked_list.readonly(); + let file_linked_list_ref = file_linked_list_arc.read().await; + + assert_eq!(*file_linked_list_ref, vec![1, 2, 3, 4]); + + drop(file_linked_list_ref); + file_linked_list + .mutate_async(|v1| async move { + let mut v = v1.write().await; + v.push(5); + v[1] = 1; + Ok::<(), Error>(()) + }) + .await + .expect("Error mutating file linked object") + .expect("Error mutating file linked object"); + + let file_linked_list_arc = file_linked_list.readonly(); + let file_linked_list_ref = file_linked_list_arc.read().await; + + assert_eq!(*file_linked_list_ref, vec![1, 1, 3, 4, 5]); + + drop(file_linked_list); + }) + .await; + } + + #[tokio::test] + async fn test_replace() { let path = PathBuf::from("test_replace"); let cleanup = CleanUp::new(&path); - cleanup.run(|p| { - let val1 = String::from("val1"); - let val2 = String::from("val2"); - let mut file_linked_list = FileLinked::new(val1.clone(), &p)?; - assert_eq!(*file_linked_list.readonly(), val1); + cleanup + .run(|p| async move { + let val1 = 
String::from("val1"); + let val2 = String::from("val2"); + let mut file_linked_list = FileLinked::new(val1.clone(), &p, DataFormat::Bincode) + .await + .expect("Unable to create file linked object"); + let file_linked_list_arc = file_linked_list.readonly(); + let file_linked_list_ref = file_linked_list_arc.read().await; - file_linked_list.replace(val2.clone())?; - assert_eq!(*file_linked_list.readonly(), val2); + assert_eq!(*file_linked_list_ref, val1); - drop(file_linked_list); - Ok(()) - }) + file_linked_list + .replace(val2.clone()) + .await + .expect("Error replacing file linked object"); + let file_linked_list_arc = file_linked_list.readonly(); + let file_linked_list_ref = file_linked_list_arc.read().await; + + assert_eq!(*file_linked_list_ref, val2); + + drop(file_linked_list); + }) + .await; } - #[test] - fn test_from_file() -> Result<(), Error> { + #[tokio::test] + async fn test_from_file() { let path = PathBuf::from("test_from_file"); let cleanup = CleanUp::new(&path); - cleanup.run(|p| { - let value: Vec = vec![2.0, 3.0, 5.0]; - let file = File::create(&p)?; + cleanup + .run(|p| async move { + let value: Vec = vec![2.0, 3.0, 5.0]; + let file = File::create(&p).expect("Unable to create file"); - bincode::serialize_into(&file, &value).expect("Unable to serialize into file"); - drop(file); + bincode::serialize_into(&file, &value).expect("Unable to serialize into file"); + drop(file); - let linked_object: FileLinked> = FileLinked::from_file(&p)?; - assert_eq!(*linked_object.readonly(), value); + let linked_object: FileLinked> = + FileLinked::from_file(&p, DataFormat::Bincode) + .expect("Unable to create file linked object"); + let linked_object_arc = linked_object.readonly(); + let linked_object_ref = linked_object_arc.read().await; - drop(linked_object); - Ok(()) - }) + assert_eq!(*linked_object_ref, value); + + drop(linked_object); + }) + .await; } } diff --git a/sbcl_spike/.gitignore b/gemla/.cargo/config.toml similarity index 100% rename from sbcl_spike/.gitignore rename to gemla/.cargo/config.toml diff --git a/gemla/Cargo.toml b/gemla/Cargo.toml index 77c42e1..4208653 100644 --- a/gemla/Cargo.toml +++ b/gemla/Cargo.toml @@ -15,18 +15,22 @@ categories = ["simulation"] [dependencies] serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" -uuid = { version = "0.8", features = ["serde", "v4"] } -clap = { version = "~2.27.0", features = ["yaml"] } -toml = "0.5.8" +uuid = { version = "1.7", features = ["serde", "v4"] } +clap = { version = "4.5.2", features = ["derive"] } +toml = "0.8.10" regex = "1" file_linked = { version = "0.1.0", path = "../file_linked" } thiserror = "1.0" anyhow = "1.0" -rand = "0.8.4" -log = "0.4.14" -env_logger = "0.9.0" -futures = "0.3.17" -smol = "1.2.5" -smol-potat = "1.1.2" -num_cpus = "1.13.0" -easy-parallel = "3.1.0" \ No newline at end of file +rand = "0.8.5" +log = "0.4.21" +env_logger = "0.11.3" +futures = "0.3.30" +tokio = { version = "1.37.0", features = ["full"] } +num_cpus = "1.16.0" +easy-parallel = "3.3.1" +fann = "0.1.8" +async-trait = "0.1.78" +async-recursion = "1.1.0" +lerp = "0.5.0" +console-subscriber = "0.2.0" diff --git a/gemla/build.rs b/gemla/build.rs new file mode 100644 index 0000000..e6b8ca6 --- /dev/null +++ b/gemla/build.rs @@ -0,0 +1,11 @@ +fn main() { + // Replace this with the path to the directory containing `fann.lib` + let lib_dir = "F://vandomej/Downloads/vcpkg/packages/fann_x64-windows/lib"; + + println!("cargo:rustc-link-search=native={}", lib_dir); + println!("cargo:rustc-link-lib=static=fann"); + // Use 
`dylib=fann` instead of `static=fann` if you're linking dynamically + + // If there are any additional directories where the compiler can find header files, you can specify them like this: + // println!("cargo:include={}", path_to_include_directory); +} diff --git a/gemla/cli.yml b/gemla/cli.yml deleted file mode 100644 index bb8a0cc..0000000 --- a/gemla/cli.yml +++ /dev/null @@ -1,9 +0,0 @@ -name: GEMLA -version: "0.1" -autor: Jacob VanDomelen -about: Uses a genetic algorithm to generate a machine learning algorithm. -args: - - FILE: - help: Sets the input/output file for the program. - required: true - index: 1 \ No newline at end of file diff --git a/gemla/nodes.toml b/gemla/nodes.toml deleted file mode 100644 index 976061d..0000000 --- a/gemla/nodes.toml +++ /dev/null @@ -1,15 +0,0 @@ -[[nodes]] -fabric_addr = "10.0.0.1:9999" -bridge_bind = "10.0.0.1:8888" -mem = "100 GiB" -cpu = 8 - -# [[nodes]] -# fabric_addr = "10.0.0.2:9999" -# mem = "100 GiB" -# cpu = 16 - -# [[nodes]] -# fabric_addr = "10.0.0.3:9999" -# mem = "100 GiB" -# cpu = 16 \ No newline at end of file diff --git a/gemla/src/bin/bin.rs b/gemla/src/bin/bin.rs index 69bfdac..96248cb 100644 --- a/gemla/src/bin/bin.rs +++ b/gemla/src/bin/bin.rs @@ -1,74 +1,72 @@ -#[macro_use] extern crate clap; extern crate gemla; #[macro_use] extern crate log; +mod fighter_nn; mod test_state; -use anyhow::anyhow; -use clap::App; -use easy_parallel::Parallel; +use anyhow::Result; +use clap::Parser; +use fighter_nn::FighterNN; +use file_linked::constants::data_format::DataFormat; use gemla::{ - constants::args::FILE, core::{Gemla, GemlaConfig}, - error::{log_error, Error}, + error::log_error, }; -use smol::{channel, channel::RecvError, future, Executor}; use std::{path::PathBuf, time::Instant}; -use test_state::TestState; + +// const NUM_THREADS: usize = 2; + +#[derive(Parser)] +#[command(version, about, long_about = None)] +struct Args { + /// The file to read/write the dataset from/to. + #[arg(short, long)] + file: String, +} /// Runs a simluation of a genetic algorithm against a dataset. /// /// Use the -h, --h, or --help flag to see usage syntax. /// TODO -fn main() -> anyhow::Result<()> { +fn main() -> Result<()> { env_logger::init(); - info!("Starting"); + // console_subscriber::init(); + info!("Starting"); let now = Instant::now(); - // Obtainning number of threads to use - let num_threads = num_cpus::get().max(1); - let ex = Executor::new(); - let (signal, shutdown) = channel::unbounded::<()>(); + // Manually configure the Tokio runtime + let runtime: Result<()> = tokio::runtime::Builder::new_multi_thread() + .worker_threads(num_cpus::get()) + // .worker_threads(NUM_THREADS) + .build()? + .block_on(async { + let args = Args::parse(); // Assuming Args::parse() doesn't need to be async + let mut gemla = log_error( + Gemla::::new( + &PathBuf::from(args.file), + GemlaConfig { overwrite: false }, + DataFormat::Json, + ) + .await, + )?; - // Create an executor thread pool. - let (_, result): (Vec>, Result<(), Error>) = Parallel::new() - .each(0..num_threads, |_| { - future::block_on(ex.run(shutdown.recv())) - }) - .finish(|| { - smol::block_on(async { - drop(signal); + // let gemla_arc = Arc::new(gemla); - // Command line arguments are parsed with the clap crate. And this program uses - // the yaml method with clap. 
- let yaml = load_yaml!("../../cli.yml"); - let matches = App::from_yaml(yaml).get_matches(); + // Setup your application logic here + // If `gemla::simulate` needs to run sequentially, simply call it in sequence without spawning new tasks - // Checking that the first argument is a valid file - if let Some(file_path) = matches.value_of(FILE) { - let mut gemla = log_error(Gemla::::new( - &PathBuf::from(file_path), - GemlaConfig { - generations_per_node: 3, - overwrite: true, - }, - ))?; - - log_error(gemla.simulate(3).await)?; - - Ok(()) - } else { - Err(Error::Other(anyhow!("Invalid argument for FILE"))) - } - }) + // Example placeholder loop to continuously run simulate + loop { + // Arbitrary loop count for demonstration + gemla.simulate(1).await?; + } }); - result?; + runtime?; // Handle errors from the block_on call info!("Finished in {:?}", now.elapsed()); - Ok(()) } diff --git a/gemla/src/bin/fighter_nn/fighter_context.rs b/gemla/src/bin/fighter_nn/fighter_context.rs new file mode 100644 index 0000000..56c328f --- /dev/null +++ b/gemla/src/bin/fighter_nn/fighter_context.rs @@ -0,0 +1,79 @@ +use std::sync::Arc; + +use serde::ser::SerializeTuple; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; +use tokio::sync::Semaphore; + +const SHARED_SEMAPHORE_CONCURRENCY_LIMIT: usize = 50; +const VISIBLE_SIMULATIONS_CONCURRENCY_LIMIT: usize = 1; + +#[derive(Debug, Clone)] +pub struct FighterContext { + pub shared_semaphore: Arc, + pub visible_simulations: Arc, +} + +impl Default for FighterContext { + fn default() -> Self { + FighterContext { + shared_semaphore: Arc::new(Semaphore::new(SHARED_SEMAPHORE_CONCURRENCY_LIMIT)), + visible_simulations: Arc::new(Semaphore::new(VISIBLE_SIMULATIONS_CONCURRENCY_LIMIT)), + } + } +} + +// Custom serialization to just output the concurrency limit. +impl Serialize for FighterContext { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + // Assuming the semaphore's available permits represent the concurrency limit. + // This part is tricky since Semaphore does not expose its initial permits. + // You might need to store the concurrency limit as a separate field if this assumption doesn't hold. + let concurrency_limit = SHARED_SEMAPHORE_CONCURRENCY_LIMIT; + let visible_concurrency_limit = VISIBLE_SIMULATIONS_CONCURRENCY_LIMIT; + // serializer.serialize_u64(concurrency_limit as u64) + + // Serialize the concurrency limit as a tuple + let mut state = serializer.serialize_tuple(2)?; + state.serialize_element(&concurrency_limit)?; + state.serialize_element(&visible_concurrency_limit)?; + state.end() + } +} + +// Custom deserialization to reconstruct the FighterContext from a concurrency limit. 
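+// For reference, a serde_json round-trip of `FighterContext::default()` stores only the two limits as the tuple `[50,1]`;
+// deserialization below ignores the stored values and rebuilds fresh semaphores from the constants.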
+impl<'de> Deserialize<'de> for FighterContext { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + // Deserialize the tuple + let (_, _) = <(usize, usize)>::deserialize(deserializer)?; + Ok(FighterContext { + shared_semaphore: Arc::new(Semaphore::new(SHARED_SEMAPHORE_CONCURRENCY_LIMIT)), + visible_simulations: Arc::new(Semaphore::new(VISIBLE_SIMULATIONS_CONCURRENCY_LIMIT)), + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_serialization() { + let context = FighterContext::default(); + let serialized = serde_json::to_string(&context).unwrap(); + let deserialized: FighterContext = serde_json::from_str(&serialized).unwrap(); + assert_eq!( + context.shared_semaphore.available_permits(), + deserialized.shared_semaphore.available_permits() + ); + assert_eq!( + context.visible_simulations.available_permits(), + deserialized.visible_simulations.available_permits() + ); + } +} diff --git a/gemla/src/bin/fighter_nn/mod.rs b/gemla/src/bin/fighter_nn/mod.rs new file mode 100644 index 0000000..4c4bd98 --- /dev/null +++ b/gemla/src/bin/fighter_nn/mod.rs @@ -0,0 +1,1631 @@ +extern crate fann; + +pub mod fighter_context; +pub mod neural_network_utility; + +use anyhow::{anyhow, Context}; +use async_trait::async_trait; +use fann::{ActivationFunc, Fann}; +use futures::future::join_all; +use gemla::{ + core::genetic_node::{GeneticNode, GeneticNodeContext}, + error::Error, +}; +use lerp::Lerp; +use rand::prelude::*; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::{ + cmp::max, + fs::{self, File}, + io::{self, BufRead, BufReader}, + ops::Range, + path::{Path, PathBuf}, +}; +use tokio::process::Command; +use uuid::Uuid; + +use self::neural_network_utility::{crossbreed, major_mutation}; + +const BASE_DIR: &str = "F:\\\\vandomej\\Projects\\dootcamp-AI-Simulation\\Simulations"; +const POPULATION: usize = 200; + +const NEURAL_NETWORK_INPUTS: usize = 22; +const NEURAL_NETWORK_OUTPUTS: usize = 8; + +const NEURAL_NETWORK_HIDDEN_LAYERS_MIN: usize = 1; +const NEURAL_NETWORK_HIDDEN_LAYERS_MAX: usize = 2; + +const NEURAL_NETWORK_HIDDEN_LAYER_SIZE_MIN: usize = 3; +const NEURAL_NETWORK_HIDDEN_LAYER_SIZE_MAX: usize = 50; + +const NEURAL_NETWORK_INITIAL_WEIGHT_MAX: f32 = 0.5; + +const NEURAL_NETWORK_MINOR_MUTATION_RATE_MAX: f32 = 0.3; +const NEURAL_NETWORK_MUTATION_WEIGHT_MAX: f32 = 1.0; + +const NEURAL_NETWORK_MAJOR_MUTATION_RATE_MAX: f32 = 1.0; + +const NEURAL_NETWORK_CROSSBREED_SEGMENTS_MIN: usize = 2; +const NEURAL_NETWORK_CROSSBREED_SEGMENTS_MAX: usize = 6; + +const OFFSHOOT_GENERATIONAL_LENIENCE: u64 = 10; +const MAINLINE_GENERATIONAL_LENIENCE: u64 = 20; + +const SIMULATION_ROUNDS: usize = 5; +const SURVIVAL_RATE_MIN: f32 = 0.1; +const SURVIVAL_RATE_MAX: f32 = 0.9; +const GAME_EXECUTABLE_PATH: &str = + "F:\\\\vandomej\\Projects\\dootcamp-AI-Simulation\\Package\\Windows\\AI_Fight_Sim.exe"; + +// Here is the folder structure for the FighterNN: +// base_dir/fighter_nn_{fighter_id}/{generation}/{fighter_id}_fighter_nn_{nn_id}.net + +// A neural network that utilizes the fann library to save and read nn's from files +// FighterNN contains a list of file locations for the nn's stored, all of which are stored under the same folder which is also contained. +// there is no training happening to the neural networks +// the neural networks are only used to simulate the nn's and to save and read the nn's from files +// Filenames are stored in the format of "{fighter_id}_fighter_nn_{generation}.net". 
+// The main folder contains a subfolder for each generation, containing a population of 10 nn's + +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct FighterNN { + pub id: Uuid, + pub folder: PathBuf, + pub population_size: usize, + pub generation: u64, + // A map of each nn identifier in a generation and their physics score + pub scores: Vec>, + // A map of the id of the nn in the current generation and their neural network shape + pub nn_shapes: Vec>>, + pub crossbreed_segments: usize, + pub weight_initialization_range: Range, + pub minor_mutation_rate: f32, + pub major_mutation_rate: f32, + pub mutation_weight_range: Range, + // Shows how individuals are mapped from one generation to the next + pub id_mapping: Vec>, + pub lerp_amount: f32, + pub generational_lenience: u64, + pub survival_rate: f32, +} + +#[async_trait] +impl GeneticNode for FighterNN { + type Context = fighter_context::FighterContext; + + // Check for the highest number of the folder name and increment it by 1 + async fn initialize(context: GeneticNodeContext) -> Result, Error> { + let base_path = PathBuf::from(BASE_DIR); + + let folder = base_path.join(format!("fighter_nn_{:06}", context.id)); + // Ensures directory is created if it doesn't exist and does nothing if it exists + fs::create_dir_all(&folder) + .with_context(|| format!("Failed to create or access the folder: {:?}", folder))?; + + //Create a new directory for the first generation, using create_dir_all to avoid errors if it already exists + let gen_folder = folder.join("0"); + fs::create_dir_all(&gen_folder).with_context(|| { + format!( + "Failed to create or access the generation folder: {:?}", + gen_folder + ) + })?; + + let mut nn_shapes = HashMap::new(); + let weight_initialization_amplitude = thread_rng().gen_range(0.0..NEURAL_NETWORK_INITIAL_WEIGHT_MAX); + let weight_initialization_range = -weight_initialization_amplitude..weight_initialization_amplitude; + + // Create the first generation in this folder + for i in 0..POPULATION { + // Filenames are stored in the format of "xxxxxx_fighter_nn_0.net", "xxxxxx_fighter_nn_1.net", etc. 
Where xxxxxx is the folder name + let nn = gen_folder + .join(format!("{:06}_fighter_nn_{}", context.id, i)) + .with_extension("net"); + + // Randomly generate a neural network shape based on constants + let hidden_layers = thread_rng() + .gen_range(NEURAL_NETWORK_HIDDEN_LAYERS_MIN..=NEURAL_NETWORK_HIDDEN_LAYERS_MAX); + let mut nn_shape = vec![NEURAL_NETWORK_INPUTS as u32]; + for _ in 0..hidden_layers { + nn_shape.push(thread_rng().gen_range( + NEURAL_NETWORK_HIDDEN_LAYER_SIZE_MIN..=NEURAL_NETWORK_HIDDEN_LAYER_SIZE_MAX, + ) as u32); + } + nn_shape.push(NEURAL_NETWORK_OUTPUTS as u32); + nn_shapes.insert(i as u64, nn_shape.clone()); + + let mut fann = Fann::new(nn_shape.as_slice()).with_context(|| "Failed to create nn")?; + fann.randomize_weights( + weight_initialization_range.start, + weight_initialization_range.end, + ); + fann.set_activation_func_hidden(ActivationFunc::SigmoidSymmetric); + fann.set_activation_func_output(ActivationFunc::SigmoidSymmetric); + // This will overwrite any existing file with the same name + fann.save(&nn) + .with_context(|| format!("Failed to save nn at {:?}", nn))?; + } + + let mut crossbreed_segments = thread_rng().gen_range( + NEURAL_NETWORK_CROSSBREED_SEGMENTS_MIN..=NEURAL_NETWORK_CROSSBREED_SEGMENTS_MAX, + ); + if crossbreed_segments % 2 == 0 { + crossbreed_segments += 1; + } + + let mutation_weight_amplitude = thread_rng().gen_range(0.0..NEURAL_NETWORK_MUTATION_WEIGHT_MAX); + + Ok(Box::new(FighterNN { + id: context.id, + folder, + population_size: POPULATION, + generation: 0, + scores: vec![], + nn_shapes: vec![nn_shapes], + // we need crossbreed segments to be even + crossbreed_segments, + weight_initialization_range, + minor_mutation_rate: thread_rng().gen_range(0.0..NEURAL_NETWORK_MINOR_MUTATION_RATE_MAX), + major_mutation_rate: thread_rng().gen_range(0.0..NEURAL_NETWORK_MAJOR_MUTATION_RATE_MAX), + mutation_weight_range: -mutation_weight_amplitude..mutation_weight_amplitude, + id_mapping: vec![], + lerp_amount: 0.0, + generational_lenience: OFFSHOOT_GENERATIONAL_LENIENCE, + survival_rate: thread_rng().gen_range(SURVIVAL_RATE_MIN..SURVIVAL_RATE_MAX), + })) + } + + async fn simulate( + &mut self, + context: GeneticNodeContext, + ) -> Result { + debug!("Context: {:?}", context); + let mut matches = Vec::new(); + let mut allotted_simulations = Vec::new(); + for i in 0..self.population_size { + allotted_simulations.push((i, SIMULATION_ROUNDS)); + } + + while !allotted_simulations.is_empty() { + let primary_id = { + let id = thread_rng().gen_range(0..allotted_simulations.len()); + let (i, _) = allotted_simulations[id]; + // Decrement the number of simulations left for this nn + allotted_simulations[id].1 -= 1; + // Remove the nn from the list if it has no more simulations left + if allotted_simulations[id].1 == 0 { + allotted_simulations.remove(id); + } + i + }; + + + let secondary_id = loop { + if allotted_simulations.is_empty() || allotted_simulations.len() == 1 { + // Select a random id + let random_id = loop { + let id = thread_rng().gen_range(0..self.population_size); + if id != primary_id { + allotted_simulations.clear(); + break id; + } + }; + + break random_id; + } + + let id = thread_rng().gen_range(0..allotted_simulations.len()); + let (i, _) = allotted_simulations[id]; + + if i != primary_id { + // Decrement the number of simulations left for this nn + allotted_simulations[id].1 -= 1; + // Remove the nn from the list if it has no more simulations left + if allotted_simulations[id].1 == 0 { + allotted_simulations.remove(id); + } + break i; + } + }; 
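+            // Pair two distinct individuals for a 1v1 match; each individual is scheduled for roughly SIMULATION_ROUNDS matches.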
+ + matches.push((primary_id, secondary_id)); + } + + debug!("Matches determined"); + trace!("Matches: {:?}", matches); + + // Create a channel to send the scores back to the main thread + let mut tasks = Vec::new(); + + for (primary_id, secondary_id) in matches.iter() { + let task = { + let self_clone = self.clone(); + let semaphore_clone = context.gemla_context.shared_semaphore.clone(); + let display_simulation_semaphore = context.gemla_context.visible_simulations.clone(); + + let folder = self_clone.folder.clone(); + let generation = self_clone.generation; + + let primary_nn = self_clone + .folder + .join(format!("{}", self_clone.generation)) + .join(self_clone.get_individual_id(*primary_id as u64)) + .with_extension("net"); + let secondary_nn = folder + .join(format!("{}", generation)) + .join(self_clone.get_individual_id(*secondary_id as u64)) + .with_extension("net"); + + // Introducing a new scope for acquiring permits and running simulations + let simulation_result = async move { + let permit = semaphore_clone.acquire_owned().await + .with_context(|| "Failed to acquire semaphore permit")?; + + let display_simulation = match display_simulation_semaphore.try_acquire_owned() { + Ok(s) => Some(s), + Err(_) => None, + }; + + let (primary_score, secondary_score) = if let Some(display_simulation) = display_simulation { + let result = run_1v1_simulation(&primary_nn, &secondary_nn, true).await?; + drop(display_simulation); // Explicitly dropping resources no longer needed + result + } else { + run_1v1_simulation(&primary_nn, &secondary_nn, false).await? + }; + + drop(permit); // Explicitly dropping resources no longer needed + + debug!( + "{} vs {} -> {} vs {}", + primary_id, secondary_id, primary_score, secondary_score + ); + + Ok((*primary_id, primary_score, *secondary_id, secondary_score)) + }; // Await the scoped async block immediately + + // The result of the simulation, whether Ok or Err, is returned here. + // This ensures tx is dropped when the block exits, regardless of success or failure. + simulation_result + }; + + tasks.push(task); + } + + debug!("Tasks created"); + + let results: Vec> = join_all(tasks).await; + + debug!("Tasks completed"); + + // resolve results for any errors + let mut scores = HashMap::new(); + for result in results.into_iter() { + let (primary_id, primary_score, secondary_id, secondary_score) = result.with_context(|| "Failed to run simulation")?; + + // If score exists, add the new score to the existing score + if let Some((existing_score, count)) = scores.get_mut(&(primary_id as u64)) { + *existing_score += primary_score; + *count += 1; + } else { + scores.insert(primary_id as u64, (primary_score, 1)); + } + + // If score exists, add the new score to the existing score + if let Some((existing_score, count)) = scores.get_mut(&(secondary_id as u64)) { + *existing_score += secondary_score; + *count += 1; + } else { + scores.insert(secondary_id as u64, (secondary_score, 1)); + } + } + + // Average scores for each individual + let mut final_scores = HashMap::new(); + for (i, (score, count)) in scores.iter() { + final_scores.insert(*i, *score / *count as f32); + } + + self.scores.push(final_scores); + + Ok(should_continue(&self.scores, self.generational_lenience)?) 
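+        // The boolean from `should_continue` (derived from the score history and the generational lenience window)
+        // tells the caller whether this node should keep simulating.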
+ } + + async fn mutate(&mut self, _context: GeneticNodeContext) -> Result<(), Error> { + let survivor_count = (self.population_size as f32 * self.survival_rate) as usize; + let mut nn_sizes = Vec::new(); + let mut id_mapping = HashMap::new(); + + // Create the new generation folder + let new_gen_folder = self.folder.join(format!("{}", self.generation + 1)); + fs::create_dir_all(&new_gen_folder).with_context(|| { + format!( + "Failed to create or access new generation folder: {:?}", + new_gen_folder + ) + })?; + + // Remove the 5 nn's with the lowest scores + let mut sorted_scores: Vec<_> = self.scores[self.generation as usize].iter().collect(); + sorted_scores.sort_by(|a, b| b.1.partial_cmp(a.1).unwrap()); + let scores_to_keep: Vec<&(&u64, &f32)> = + sorted_scores.iter().take(survivor_count).collect(); + let to_keep = scores_to_keep.iter().map(|(k, _)| *k).collect::>(); + + // Save the remaining 5 nn's to the new generation folder + for (i, nn_id) in to_keep.iter().enumerate().take(survivor_count) { + let nn = self + .folder + .join(format!("{}", self.generation)) + .join(format!("{:06}_fighter_nn_{}.net", self.id, nn_id)); + let new_nn = new_gen_folder.join(format!("{:06}_fighter_nn_{}.net", self.id, i)); + debug!("Copying nn from {:?} to {:?}", nn_id, i); + id_mapping.insert(**nn_id, i as u64); + fs::copy(&nn, &new_nn)?; + nn_sizes.push( + self.nn_shapes[self.generation as usize] + .get(nn_id) + .unwrap() + .clone(), + ); + } + + let weights: HashMap = scores_to_keep.iter().map(|(k, v)| (**k, **v)).collect(); + + debug!("scores: {:?}", scores_to_keep); + + let mut tasks = Vec::new(); + + // Take the remaining nn's and create new nn's by the following: + for i in 0..(self.population_size - survivor_count) { + let self_clone = self.clone(); + + // randomly select individual id's sorted scores proportional to their score + let nn_id = weighted_random_selection(&weights); + let nn = self_clone + .folder + .join(format!("{}", self_clone.generation)) + .join(self_clone.get_individual_id(nn_id)) + .with_extension("net"); + + // Load another nn from the current generation and cross breed it with the current nn + let cross_id = loop { + let cross_id = weighted_random_selection(&weights); + if cross_id != nn_id { + break cross_id; + } + }; + + let cross_nn = self_clone + .folder + .join(format!("{}", self_clone.generation)) + .join(self_clone.get_individual_id(cross_id)) + .with_extension("net"); + + let new_gen_folder = new_gen_folder.clone(); + + let future = tokio::task::spawn_blocking(move || -> Result, Error> { + let fann = Fann::from_file(&nn).with_context(|| "Failed to load nn")?; + let cross_fann = + Fann::from_file(&cross_nn).with_context(|| "Failed to load cross nn")?; + + let mut new_fann = crossbreed( + &self_clone, + &fann, + &cross_fann, + self_clone.crossbreed_segments, + )?; + + // For each weight in the 5 new nn's there is a 20% chance of a minor mutation (a random number between -0.1 and 0.1 is added to the weight) + // And a 5% chance of a major mutation a new neuron is randomly added to a hidden layer + let mut connections = new_fann.get_connections(); // Vector of connections + for c in &mut connections { + if thread_rng().gen_range(0.0..1.0) < self_clone.minor_mutation_rate { + trace!("Minor mutation on connection {:?}", c); + c.weight += + thread_rng().gen_range(self_clone.weight_initialization_range.clone()); + trace!("New weight: {}", c.weight); + } + } + + new_fann.set_connections(&connections); + + if thread_rng().gen_range(0.0..1.0) < self_clone.major_mutation_rate { + 
new_fann = + major_mutation(&new_fann, self_clone.weight_initialization_range.clone())?; + } + + let new_nn = new_gen_folder + .join(self_clone.get_individual_id((i + survivor_count) as u64)) + .with_extension("net"); + new_fann.save(new_nn).with_context(|| "Failed to save nn")?; + + Ok::, Error>(new_fann.get_layer_sizes()) + }); + + tasks.push(future); + } + + let results = join_all(tasks).await; + + for result in results.into_iter() { + let new_size = result.with_context(|| "Failed to create new nn")??; + nn_sizes.push(new_size); + } + + // Use the index of nn_sizes to generate the id for the nn_sizes HashMap + let nn_sizes_map = nn_sizes + .into_iter() + .enumerate() + .map(|(i, v)| (i as u64, v)) + .collect::>(); + + self.generation += 1; + self.nn_shapes.push(nn_sizes_map); + self.id_mapping.push(id_mapping); + + Ok(()) + } + + async fn merge( + left: &FighterNN, + right: &FighterNN, + id: &Uuid, + gemla_context: Self::Context, + ) -> Result, Error> { + let base_path = PathBuf::from(BASE_DIR); + let folder = base_path.join(format!("fighter_nn_{:06}", id)); + + // Ensure the folder exists, including the generation subfolder. + fs::create_dir_all(folder.join("0")) + .with_context(|| format!("Failed to create directory {:?}", folder.join("0")))?; + + let get_highest_scores = |fighter: &FighterNN| -> Vec<(u64, f32)> { + let mut sorted_scores: Vec<_> = + fighter.scores[fighter.generation as usize].iter().collect(); + sorted_scores.sort_by(|a, b| b.1.partial_cmp(a.1).unwrap()); + sorted_scores + .iter() + .take(fighter.population_size / 2) + .map(|(k, v)| (**k, **v)) + .collect() + }; + + let left_scores = get_highest_scores(left); + let right_scores = get_highest_scores(right); + + debug!("Left scores: {:?}", left_scores); + debug!("Right scores: {:?}", right_scores); + + let mut simulations = Vec::new(); + + let left_weights: HashMap = left_scores.iter().map(|(k, v)| (*k, *v)).collect(); + let right_weights: HashMap = right_scores.iter().map(|(k, v)| (*k, *v)).collect(); + + let num_simulations = max(left.population_size, right.population_size) * SIMULATION_ROUNDS; + + for _ in 0..num_simulations { + let left_nn_id = weighted_random_selection(&left_weights); + let right_nn_id = weighted_random_selection(&right_weights); + + let left_nn_path = left + .folder + .join(left.generation.to_string()) + .join(left.get_individual_id(left_nn_id)) + .with_extension("net"); + let right_nn_path = right + .folder + .join(right.generation.to_string()) + .join(right.get_individual_id(right_nn_id)) + .with_extension("net"); + let semaphore_clone = gemla_context.shared_semaphore.clone(); + let display_simulation_semaphore = gemla_context.visible_simulations.clone(); + + let future = async move { + let permit = semaphore_clone + .acquire_owned() + .await + .with_context(|| "Failed to acquire semaphore permit")?; + + let display_simulation = match display_simulation_semaphore.try_acquire_owned() { + Ok(s) => Some(s), + Err(_) => None, + }; + + let (left_score, right_score) = if let Some(display_simulation) = display_simulation + { + let result = run_1v1_simulation(&left_nn_path, &right_nn_path, true).await?; + drop(display_simulation); + result + } else { + run_1v1_simulation(&left_nn_path, &right_nn_path, false).await? 
+ }; + + debug!("{} vs {} -> {} vs {}", left_nn_id, right_nn_id, left_score, right_score); + + drop(permit); + + Ok::<(f32, f32), Error>((left_score, right_score)) + }; + + simulations.push(future); + } + + let results: Result, Error> = + join_all(simulations).await.into_iter().collect(); + let scores = results?; + + let total_left_score = scores.iter().map(|(l, _)| l).sum::() / num_simulations as f32; + let total_right_score = scores.iter().map(|(_, r)| r).sum::() / num_simulations as f32; + + debug!("Total left score: {}", total_left_score); + debug!("Total right score: {}", total_right_score); + + let score_difference = total_right_score - total_left_score; + // Use the sigmoid function to determine lerp amount + let lerp_amount = 1.0 / (1.0 + (-score_difference).exp()); + + debug!("Lerp amount: {}", lerp_amount); + + let mut nn_shapes = HashMap::new(); + + // Function to copy NNs from a source FighterNN to the new folder. + let mut copy_nns = |source: &FighterNN, + folder: &PathBuf, + id: &Uuid, + start_idx: usize| + -> Result<(), Error> { + let mut sorted_scores: Vec<_> = + source.scores[source.generation as usize].iter().collect(); + sorted_scores.sort_by(|a, b| a.1.partial_cmp(b.1).unwrap()); + let remaining = sorted_scores[(source.population_size / 2)..] + .iter() + .map(|(k, _)| *k) + .collect::>(); + + for (i, nn_id) in remaining.into_iter().enumerate() { + let nn_path = source + .folder + .join(source.generation.to_string()) + .join(format!("{:06}_fighter_nn_{}.net", source.id, nn_id)); + let new_nn_path = + folder + .join("0") + .join(format!("{:06}_fighter_nn_{}.net", id, start_idx + i)); + fs::copy(&nn_path, &new_nn_path).with_context(|| { + format!("Failed to copy nn from {:?} to {:?}", nn_path, new_nn_path) + })?; + + let nn_shape = source.nn_shapes[source.generation as usize] + .get(nn_id) + .unwrap(); + + nn_shapes.insert((start_idx + i) as u64, nn_shape.clone()); + } + + Ok(()) + }; + + // Copy the top half of NNs from each parent to the new folder. 
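+    // Each parent contributes its best-scoring half: `left`'s survivors land at indices
+    // 0..left.population_size / 2 and `right`'s fill the next block, and the merged node's
+    // population_size becomes the number of copied networks. The `lerp_amount` computed above
+    // then blends the parents' hyperparameters: it is ~0.5 when the averaged scores are equal
+    // and approaches 1.0 (i.e. favouring `right`'s values) as right's average score pulls ahead.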
+ copy_nns(left, &folder, id, 0)?; + copy_nns(right, &folder, id, left.population_size / 2)?; + + debug!("nn_shapes: {:?}", nn_shapes); + + // Lerp the mutation rates and weight ranges + let crossbreed_segments = (left.crossbreed_segments as f32) + .lerp(right.crossbreed_segments as f32, lerp_amount) + as usize; + + let weight_initialization_range_start = left + .weight_initialization_range + .start + .lerp(right.weight_initialization_range.start, lerp_amount); + let weight_initialization_range_end = left + .weight_initialization_range + .end + .lerp(right.weight_initialization_range.end, lerp_amount); + // Have to ensure the range is valid + let weight_initialization_range = + if weight_initialization_range_start < weight_initialization_range_end { + weight_initialization_range_start..weight_initialization_range_end + } else { + weight_initialization_range_end..weight_initialization_range_start + }; + + debug!( + "weight_initialization_range: {:?}", + weight_initialization_range + ); + + let minor_mutation_rate = left + .minor_mutation_rate + .lerp(right.minor_mutation_rate, lerp_amount); + let major_mutation_rate = left + .major_mutation_rate + .lerp(right.major_mutation_rate, lerp_amount); + + debug!("minor_mutation_rate: {}", minor_mutation_rate); + debug!("major_mutation_rate: {}", major_mutation_rate); + + let mutation_weight_range_start = left + .mutation_weight_range + .start + .lerp(right.mutation_weight_range.start, lerp_amount); + let mutation_weight_range_end = left + .mutation_weight_range + .end + .lerp(right.mutation_weight_range.end, lerp_amount); + // Have to ensure the range is valid + let mutation_weight_range = if mutation_weight_range_start < mutation_weight_range_end { + mutation_weight_range_start..mutation_weight_range_end + } else { + mutation_weight_range_end..mutation_weight_range_start + }; + + debug!("mutation_weight_range: {:?}", mutation_weight_range); + + let survival_rate = left.survival_rate.lerp(right.survival_rate, lerp_amount); + + debug!("survival_rate: {}", survival_rate); + + Ok(Box::new(FighterNN { + id: *id, + folder, + generation: 0, + population_size: nn_shapes.len(), + scores: vec![], + crossbreed_segments, + nn_shapes: vec![nn_shapes], + weight_initialization_range, + minor_mutation_rate, + major_mutation_rate, + mutation_weight_range, + id_mapping: vec![], + lerp_amount, + // generational_lenience: left.generational_lenience + MAINLINE_GENERATIONAL_LENIENCE, + generational_lenience: MAINLINE_GENERATIONAL_LENIENCE, + survival_rate, + })) + } +} + +impl FighterNN { + pub fn get_individual_id(&self, nn_id: u64) -> String { + format!("{:06}_fighter_nn_{}", self.id, nn_id) + } +} + +fn should_continue(scores: &[HashMap], lenience: u64) -> Result { + if scores.len() < lenience as usize { + return Ok(true); + } + + let mut highest_q3_value = f32::MIN; + let mut generation_with_highest_q3 = 0; + + let mut highest_median = f32::MIN; + let mut generation_with_highest_median = 0; + + for (generation_index, generation) in scores.iter().enumerate() { + let mut scores: Vec = generation.values().copied().collect(); + scores.sort_by(|a, b| a.partial_cmp(b).unwrap()); + + let q3_index = (scores.len() as f32 * 0.75).ceil() as usize - 1; + let q3_value = scores + .get(q3_index) + .ok_or(anyhow!("Failed to get Q3 value"))?; + + if *q3_value > highest_q3_value { + highest_q3_value = *q3_value; + generation_with_highest_q3 = generation_index; + } + + let median_index = (scores.len() as f32 * 0.5).ceil() as usize - 1; + let median_value = scores + .get(median_index) 
+ .ok_or(anyhow!("Failed to get median value"))?; + + if *median_value > highest_median { + highest_median = *median_value; + generation_with_highest_median = generation_index; + } + } + + let highest_generation_index = scores.len() - 1; + let result = highest_generation_index - generation_with_highest_q3 < lenience as usize + && highest_generation_index - generation_with_highest_median < lenience as usize; + + debug!( + "Highest Q3 value: {} at generation {}, Highest Median value: {} at generation {}, Continuing? {}", + highest_q3_value, generation_with_highest_q3 + 1, highest_median, generation_with_highest_median + 1, result + ); + + Ok(result) +} + +fn weighted_random_selection(weights: &HashMap) -> T { + let mut rng = thread_rng(); + + // Identify the minimum weight + let min_weight = weights.values().fold(f32::INFINITY, |a, &b| a.min(b)); + + // Adjust all weights to be non-negative + let offset = if min_weight < 0.0 { + (-min_weight) + 0.5 + } else { + 0.0 + }; + let total_weight: f32 = weights.values().map(|w| w + offset).sum(); + + let mut cumulative_weight = 0.0; + let random_weight = rng.gen::() * total_weight; + + for (item, weight) in weights.iter() { + cumulative_weight += *weight + offset; + if cumulative_weight >= random_weight { + return item.clone(); + } + } + + panic!("Weighted random selection failed."); +} + +async fn run_1v1_simulation( + nn_path_1: &Path, + nn_path_2: &Path, + display_simulation: bool, +) -> Result<(f32, f32), Error> { + // Construct the score file path + let base_folder = nn_path_1.parent().unwrap(); + let nn_1_id = nn_path_1.file_stem().unwrap().to_str().unwrap(); + let nn_2_id = nn_path_2.file_stem().unwrap().to_str().unwrap(); + let score_file = base_folder.join(format!("{}_vs_{}.txt", nn_1_id, nn_2_id)); + + // Check if score file already exists before running the simulation + if score_file.exists() { + let round_score = read_score_from_file(&score_file, nn_1_id) + .await + .with_context(|| format!("Failed to read score from file: {:?}", score_file))?; + + let opposing_score = read_score_from_file(&score_file, nn_2_id) + .await + .with_context(|| format!("Failed to read score from file: {:?}", score_file))?; + + trace!( + "{} scored {}, while {} scored {}", + nn_1_id, round_score, nn_2_id, opposing_score + ); + + return Ok((round_score, opposing_score)); + } + + // Check if the opposite round score has been determined + let opposite_score_file = base_folder.join(format!("{}_vs_{}.txt", nn_2_id, nn_1_id)); + if opposite_score_file.exists() { + let round_score = read_score_from_file(&opposite_score_file, nn_1_id) + .await + .with_context(|| { + format!("Failed to read score from file: {:?}", opposite_score_file) + })?; + + let opposing_score = read_score_from_file(&opposite_score_file, nn_2_id) + .await + .with_context(|| { + format!("Failed to read score from file: {:?}", opposite_score_file) + })?; + + trace!( + "{} scored {}, while {} scored {}", + nn_1_id, round_score, nn_2_id, opposing_score + ); + + return Ok((round_score, opposing_score)); + } + + // Run simulation until score file is generated + let config1_arg = format!("-NN1Config=\"{}\"", nn_path_1.to_str().unwrap()); + let config2_arg = format!("-NN2Config=\"{}\"", nn_path_2.to_str().unwrap()); + let disable_unreal_rendering_arg = "-nullrhi".to_string(); + + trace!( + "Executing the following command {} {} {} {}", + GAME_EXECUTABLE_PATH, + config1_arg, + config2_arg, + disable_unreal_rendering_arg + ); + + trace!("Running simulation for {} vs {}", nn_1_id, nn_2_id); + + let _output = 
if display_simulation { + Command::new(GAME_EXECUTABLE_PATH) + .arg(&config1_arg) + .arg(&config2_arg) + .output() + .await + .expect("Failed to execute game") + } else { + Command::new(GAME_EXECUTABLE_PATH) + .arg(&config1_arg) + .arg(&config2_arg) + .arg(&disable_unreal_rendering_arg) + .output() + .await + .expect("Failed to execute game") + }; + + trace!( + "Simulation completed for {} vs {}: {}", + nn_1_id, + nn_2_id, + score_file.exists() + ); + + // Read the score from the file + if score_file.exists() { + let round_score = read_score_from_file(&score_file, nn_1_id) + .await + .with_context(|| format!("Failed to read score from file: {:?}", score_file))?; + + let opposing_score = read_score_from_file(&score_file, nn_2_id) + .await + .with_context(|| format!("Failed to read score from file: {:?}", score_file))?; + + trace!( + "{} scored {}, while {} scored {}", + nn_1_id, round_score, nn_2_id, opposing_score + ); + + Ok((round_score, opposing_score)) + } else { + warn!("Score file not found: {:?}", score_file); + Ok((0.0, 0.0)) + } +} + +async fn read_score_from_file(file_path: &Path, nn_id: &str) -> Result { + let mut attempts = 0; + + loop { + match File::open(file_path) { + Ok(file) => { + let reader = BufReader::new(file); + + for line in reader.lines() { + let line = line?; + if line.starts_with(nn_id) { + let parts: Vec<&str> = line.split(':').collect(); + if parts.len() == 2 { + return parts[1] + .trim() + .parse::() + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e)); + } + } + } + + return Err(io::Error::new( + io::ErrorKind::NotFound, + "NN ID not found in scores file", + )); + } + Err(_) => + { + if attempts >= 2 { + // Attempt 5 times before giving up. + return Ok(-100.0); + } + + attempts += 1; + // wait 1 second to ensure the file is written + tokio::time::sleep(tokio::time::Duration::from_secs(10)).await; + } + } + } +} + +#[cfg(test)] +pub mod test { + use super::*; + + #[test] + fn test_weighted_random_selection() { + let weights = vec![ + (43, -4.0403514), + (26, -2.9386168), + (44, -2.8106647), + (46, -1.3942022), + (23, 0.99386656), + (41, -2.2198126), + (48, 1.2195103), + (42, -3.4927247), + (7, -1.092067), + (0, -0.3878999), + (49, -4.156101), + (34, -0.33209237), + (30, -2.7059758), + (2, -2.251783), + (20, -0.5811202), + (10, -3.047954), + (6, -4.3464293), + (39, -3.7280478), + (1, -3.4291298), + (11, -2.0568254), + (24, -1.5701149), + (8, -1.5029285), + (3, -2.4728038), + (4, 3.7312133), + (25, -1.227466), + ] + .into_iter() + .collect(); + + let mut ids = vec![ + 43, 26, 44, 46, 23, 41, 48, 42, 7, 0, 49, 34, 30, 2, 20, 10, 6, 39, 1, 11, 24, 8, 3, 4, + 25, + ]; + + for _ in 0..10000 { + let id = weighted_random_selection(&weights); + + ids = ids.into_iter().filter(|&x| x != id).collect(); + + assert!(weights.contains_key(&id)); + } + + assert_eq!(ids.len(), 0); + } + + #[test] + fn test_should_continue() { + let scores = vec![ + // Generation 0 + [ + (37, -7.1222725), + (12, -3.6037624), + (27, -5.202844), + (21, -6.3283415), + (4, -6.0053186), + (8, -4.040202), + (13, -4.0050435), + (17, -5.8206105), + (40, -7.5448103), + (42, -8.027704), + (15, -5.1600137), + (10, -7.9063845), + (1, -6.9830275), + (7, -3.3323112), + (16, -6.1065326), + (23, -6.417853), + (25, -6.410652), + (14, -6.5887403), + (3, -6.3966584), + (19, 0.1242948), + (28, -4.806827), + (18, -6.3310747), + (30, -5.8972425), + (31, -6.398958), + (22, -7.042196), + (29, -5.7098813), + (9, -8.931531), + (33, -5.9806275), + (6, -6.5489874), + (26, -5.892653), + (34, -6.4281516), + (35, 
-5.5369387), + (38, -5.495344), + (43, 0.9552175), + (44, -6.2549844), + (45, -8.42142), + (24, -7.121878), + (47, -5.373896), + (48, -6.445716), + (39, -6.053849), + (11, -5.8320975), + (49, -10.014197), + (46, -7.0919595), + (20, -6.033137), + (5, -6.3501267), + (32, -4.203919), + (2, -5.743471), + (36, -8.493466), + (41, -7.60419), + (0, -7.388545), + ], + // Generation 1 + [ + (18, -6.048934), + (39, -1.1448132), + (48, -7.921489), + (38, -6.0117235), + (27, -6.30289), + (9, -6.5567093), + (29, -5.905172), + (25, -4.2305975), + (40, -5.1198816), + (24, -7.232001), + (46, -6.5581756), + (20, -6.7987585), + (8, -9.346154), + (2, -7.6944494), + (3, -6.487195), + (16, -8.379641), + (32, -7.292016), + (33, -7.91467), + (41, -7.4449363), + (21, -6.0500197), + (19, -5.357873), + (10, -6.9984064), + (7, -5.6824636), + (13, -8.154273), + (45, -7.8713655), + (47, -5.279138), + (49, -1.915852), + (6, -2.682654), + (30, -5.566201), + (1, -1.829716), + (11, -7.7527223), + (12, -10.379072), + (15, -4.866212), + (35, -8.091223), + (36, -8.137203), + (42, -7.2846284), + (44, -4.7636213), + (28, -6.518874), + (34, 1.9858776), + (43, -10.140268), + (0, -3.5068736), + (17, -2.3913155), + (26, -6.1766686), + (22, -9.119884), + (14, -7.470778), + (5, -5.925585), + (23, -6.004782), + (31, -2.696432), + (4, -2.4887466), + (37, -5.5321026), + ], + // Generation 2 + [ + (25, -8.760574), + (0, -2.5970187), + (9, -4.270929), + (11, -0.27550858), + (20, -6.7012835), + (30, 2.3309054), + (4, -7.0107384), + (31, -7.5239167), + (41, -2.337672), + (6, -3.4384027), + (16, -7.9485044), + (37, -7.3155503), + (38, -7.4812994), + (3, -3.958924), + (42, -7.738173), + (43, -6.500585), + (22, -6.318394), + (17, -5.7882595), + (45, -8.782414), + (49, -8.84129), + (23, -10.222613), + (26, -6.06804), + (32, -6.4851217), + (33, -7.3542376), + (34, -2.8723297), + (27, -7.1350646), + (8, -2.7956052), + (18, -5.0000043), + (10, -1.5138103), + (2, 0.10560961), + (7, -1.4954948), + (35, -7.7015786), + (36, -8.602789), + (47, -8.117584), + (28, -9.151132), + (39, -8.035833), + (13, -6.2601876), + (15, -9.050044), + (19, -5.465233), + (44, -8.494604), + (5, -6.9012084), + (12, -9.458872), + (21, -5.980685), + (14, -7.7407913), + (46, -0.701484), + (24, -9.477325), + (29, -6.6444407), + (1, -3.4681067), + (40, -5.4685316), + (48, 0.22965483), + ], + // Generation 3 + [ + (11, -5.7744265), + (12, 0.10171394), + (18, -8.503949), + (3, -1.9760166), + (17, -7.895561), + (20, -8.515409), + (45, -1.9184738), + (6, -5.6488137), + (46, -6.1171823), + (49, -7.006673), + (29, -3.6479561), + (37, -4.025724), + (42, -4.1281996), + (9, -2.7060657), + (33, 0.18799233), + (15, -7.8216696), + (23, -11.02603), + (22, -10.132984), + (7, -6.432255), + (38, -7.2159233), + (10, -2.195277), + (2, -6.7676725), + (27, -1.8040345), + (34, -11.214028), + (40, -6.1334066), + (35, -9.410227), + (44, -0.14929143), + (47, -7.3865366), + (41, -9.200221), + (26, -6.1885824), + (13, -5.5693216), + (31, -8.184256), + (39, -8.06583), + (24, -11.773471), + (25, -15.231514), + (14, -5.4468412), + (30, -5.494699), + (21, -10.619481), + (28, -7.322004), + (16, -7.4136076), + (8, -3.2260292), + (32, -8.187313), + (19, -5.9347467), + (43, -0.112977505), + (5, -1.9279568), + (48, -3.8396995), + (0, -9.317253), + (4, -1.8099403), + (1, -5.4981036), + (36, -3.5487309), + ], + // Generation 4 + [ + (28, -6.2057357), + (40, -6.9324327), + (46, -0.5130272), + (23, -7.9489794), + (47, -7.3411865), + (20, -8.930363), + (26, -3.238875), + (41, -7.376683), + (48, -0.83026105), + (27, 
-10.048681), + (36, -5.1788163), + (30, -8.002236), + (9, -7.4656434), + (4, -3.8850121), + (16, -3.1768656), + (11, 1.0195583), + (44, -8.7163315), + (45, -6.7038856), + (33, -6.974304), + (22, -10.026589), + (13, -4.342838), + (12, -6.69588), + (31, -2.2994905), + (14, -7.9772606), + (32, -10.55702), + (38, -5.668454), + (34, -10.026564), + (37, -8.128912), + (42, -10.7178335), + (17, -5.18195), + (49, -9.900299), + (21, -12.4000635), + (8, -1.8514707), + (29, -3.365313), + (39, -5.588918), + (43, -8.482417), + (1, -4.390686), + (35, -5.604909), + (24, -7.1810236), + (25, -5.9158974), + (19, -4.5733366), + (0, -5.68081), + (3, -2.8414884), + (6, -1.5809858), + (7, -9.295659), + (5, -3.7936096), + (10, -4.088697), + (2, -2.3494315), + (15, -7.3323736), + (18, -7.7137175), + ], + // Generation 5 + [ + (1, -2.7719336), + (37, -6.097855), + (39, -4.1296787), + (2, -5.4538774), + (34, -11.808794), + (40, -9.822159), + (3, -7.884645), + (42, -14.777964), + (32, -2.6564443), + (16, -5.2442584), + (9, -6.2919874), + (48, -2.4359574), + (25, -11.707236), + (33, -5.5483084), + (35, -0.3632618), + (7, -4.3673687), + (27, -8.139543), + (12, -9.019396), + (17, -0.029791832), + (24, -8.63045), + (18, -11.925819), + (20, -9.040375), + (44, -10.296264), + (47, -15.95397), + (23, -12.38116), + (21, 0.18342426), + (38, -7.695002), + (6, -8.710346), + (28, -2.8542902), + (5, -2.077858), + (10, -3.638583), + (8, -7.360152), + (15, -7.1610765), + (29, -4.8372035), + (45, -11.499393), + (13, -3.8436065), + (22, -5.472387), + (11, -4.259357), + (26, -4.847328), + (4, -2.0376666), + (36, -7.5392637), + (41, -5.3857164), + (19, -8.576212), + (14, -8.267895), + (30, -4.0456495), + (31, -3.806975), + (43, -7.9901657), + (46, -7.181662), + (0, -7.502816), + (49, -7.3067017), + ], + // Generation 6 + [ + (17, -9.793276), + (27, -2.8843281), + (38, -8.737534), + (8, -1.5083166), + (16, -8.267393), + (42, -8.055011), + (47, -2.0843022), + (14, -3.9945045), + (30, -10.208374), + (26, -3.2439823), + (49, -2.5527742), + (25, -10.359426), + (9, -4.4744225), + (19, -7.2775927), + (3, -7.282045), + (36, -8.503307), + (40, -12.083569), + (22, -3.7249084), + (18, -7.5065627), + (41, -3.3326488), + (44, -2.76882), + (45, -12.154654), + (24, -2.8332536), + (5, -5.2674284), + (4, -4.105483), + (10, -6.930478), + (20, -3.7845988), + (2, -4.4593267), + (28, -0.3003047), + (29, -6.5971193), + (32, -5.0542274), + (33, -9.068264), + (43, -7.124672), + (46, -8.358111), + (23, -5.551978), + (11, -7.7810373), + (35, -7.4763336), + (34, -10.868844), + (39, -10.51066), + (7, -4.376377), + (48, -9.093265), + (6, -0.20033613), + (1, -6.125786), + (12, -8.243349), + (0, -7.1646323), + (13, -3.7055316), + (15, -6.295897), + (21, -5.929867), + (31, -7.2123885), + (37, -2.482071), + ], + // Generation 7 + [ + (30, -12.467585), + (14, -5.1706576), + (40, -9.03964), + (18, -5.7730474), + (41, -9.061858), + (20, -2.8577142), + (24, -3.3558655), + (42, -7.902747), + (43, -6.1566644), + (21, -5.4271364), + (23, -7.1462164), + (44, -7.9898252), + (11, -2.493559), + (31, -4.6718645), + (48, -12.774545), + (8, -7.252562), + (35, -1.6866531), + (49, -4.437603), + (45, -7.164916), + (7, -4.613396), + (32, -8.156101), + (39, -10.887325), + (0, -0.18116185), + (47, -4.998584), + (10, -8.914183), + (13, -0.8690014), + (27, -0.3714923), + (28, -12.002966), + (9, -6.2789965), + (26, -0.46416503), + (2, -9.865377), + (29, -8.443848), + (46, -6.3264246), + (3, -7.807205), + (4, -6.8240366), + (5, -6.843891), + (12, -5.6381693), + (15, -4.6679296), + (36, 
-6.8010025), + (16, -8.222928), + (25, -10.326822), + (34, -6.0182467), + (37, -8.713378), + (38, -7.549215), + (17, -7.247555), + (22, -13.296148), + (33, -8.542955), + (19, -7.254419), + (1, -2.8472056), + (6, -5.898753), + ], + // Generation 8 + [ + (7, -3.6624274), + (4, -2.9281456), + (39, -5.9176188), + (13, -8.0644045), + (16, -2.0319564), + (49, -10.309226), + (3, -0.21671781), + (37, -8.295551), + (44, -16.496105), + (46, -6.2466326), + (47, -3.5928986), + (19, -9.298591), + (1, -7.937351), + (15, -8.218504), + (6, -6.945601), + (25, -8.446054), + (12, -5.8477135), + (14, -3.9165816), + (17, -2.4864268), + (20, -7.97737), + (22, -5.347026), + (0, -6.0739775), + (32, -6.7568192), + (36, -4.730008), + (28, -9.923819), + (38, -8.677519), + (42, -4.668519), + (48, 0.14014988), + (5, -8.3167), + (8, -2.5030074), + (21, -1.8195568), + (27, -6.111103), + (45, -12.708131), + (35, -8.089076), + (11, -6.0151362), + (34, -13.688166), + (33, -11.375975), + (2, -4.1082373), + (24, -4.0867376), + (10, -4.2828474), + (41, -9.174506), + (43, -1.1505331), + (29, -3.7704785), + (18, -4.9493446), + (30, -3.727829), + (31, -6.490308), + (9, -6.0947385), + (40, -9.492185), + (26, -13.629112), + (23, -9.773454), + ], + // Generation 9 + [ + (12, -1.754871), + (41, 2.712658), + (24, -4.0929146), + (18, -4.9418926), + (44, -9.325021), + (8, -6.4423165), + (1, -0.0946085), + (5, -3.0156248), + (14, -5.29519), + (34, -10.763539), + (11, -7.304751), + (20, -6.8397574), + (22, -5.6720686), + (23, -7.829904), + (7, -3.8627372), + (6, -3.1108487), + (16, -8.803584), + (36, -13.916307), + (21, -10.142917), + (37, -12.171498), + (45, -13.004938), + (19, -3.7237267), + (47, -6.0189786), + (17, -4.612711), + (15, -5.3010545), + (30, -5.671092), + (46, -13.300519), + (25, -8.2948), + (3, -10.556543), + (42, -7.041272), + (48, -9.797744), + (9, -5.6163936), + (26, -6.665021), + (27, -7.074666), + (4, -1.5992731), + (2, -6.4931273), + (29, -3.9785416), + (31, -12.222026), + (10, -2.3970482), + (40, -6.204074), + (49, -7.025599), + (28, -8.562909), + (13, -6.2592154), + (32, -10.465271), + (33, -7.7043953), + (35, -6.4584246), + (38, -2.9016697), + (39, -1.5256255), + (43, -10.858711), + (0, -4.720929), + ], + //Generation 10 + [ + (2, -5.1676617), + (3, -4.521774), + (29, -7.3104324), + (23, -6.550776), + (26, -10.467587), + (18, 1.6576093), + (33, -2.564094), + (20, -3.2697926), + (35, -13.577334), + (37, -6.0147185), + (17, -4.07909), + (0, -9.630419), + (38, -7.011383), + (12, -10.686635), + (43, -8.94728), + (48, -9.350017), + (30, -7.3335466), + (13, -7.7690034), + (4, -2.3488472), + (14, -7.2594194), + (21, -9.08367), + (34, -7.7497597), + (8, -6.2317214), + (27, -8.440135), + (22, -4.4437346), + (32, -2.194015), + (28, -6.6919556), + (40, -8.840385), + (42, -9.781796), + (15, -7.3304253), + (49, -8.720987), + (19, -9.044103), + (6, -5.715863), + (41, -8.395639), + (36, -3.995482), + (25, -9.1373005), + (5, -7.5690002), + (1, -6.0397635), + (16, -8.231512), + (10, -6.5344634), + (44, -7.749376), + (7, -9.302668), + (31, -10.868391), + (39, -2.7578635), + (47, -6.964238), + (24, -4.033315), + (11, -8.211409), + (45, -10.472969), + (9, -7.1529093), + (46, -9.653514), + ], + ]; + + // Transform scores into a vector of hashmaps instead + let scores: Vec> = scores + .iter() + .map(|gen_scores| gen_scores.iter().cloned().collect()) + .collect(); + + assert!( + should_continue(scores[..0].as_ref(), 5) + .expect("Failed to determine if the simulation should continue") + == true + ); + assert!( + 
should_continue(scores[..1].as_ref(), 5) + .expect("Failed to determine if the simulation should continue") + == true + ); + assert!( + should_continue(scores[..2].as_ref(), 5) + .expect("Failed to determine if the simulation should continue") + == true + ); + assert!( + should_continue(scores[..3].as_ref(), 5) + .expect("Failed to determine if the simulation should continue") + == true + ); + assert!( + should_continue(scores[..4].as_ref(), 5) + .expect("Failed to determine if the simulation should continue") + == true + ); + assert!( + should_continue(scores[..5].as_ref(), 5) + .expect("Failed to determine if the simulation should continue") + == true + ); + assert!( + should_continue(scores[..6].as_ref(), 5) + .expect("Failed to determine if the simulation should continue") + == true + ); + assert!( + should_continue(scores[..7].as_ref(), 5) + .expect("Failed to determine if the simulation should continue") + == true + ); + assert!( + should_continue(scores[..8].as_ref(), 5) + .expect("Failed to determine if the simulation should continue") + == false + ); + assert!( + should_continue(scores[..9].as_ref(), 5) + .expect("Failed to determine if the simulation should continue") + == false + ); + assert!( + should_continue(scores[..10].as_ref(), 5) + .expect("Failed to determine if the simulation should continue") + == false + ); + } +} diff --git a/gemla/src/bin/fighter_nn/neural_network_utility.rs b/gemla/src/bin/fighter_nn/neural_network_utility.rs new file mode 100644 index 0000000..7fe922e --- /dev/null +++ b/gemla/src/bin/fighter_nn/neural_network_utility.rs @@ -0,0 +1,1825 @@ +use std::{cmp::min, collections::HashMap, ops::Range}; + +use anyhow::Context; +use fann::{ActivationFunc, Fann}; +use gemla::error::Error; +use rand::{ + distributions::{Distribution, Uniform}, + seq::IteratorRandom, + thread_rng, Rng, +}; + +use super::{ + FighterNN, NEURAL_NETWORK_HIDDEN_LAYER_SIZE_MAX, NEURAL_NETWORK_HIDDEN_LAYER_SIZE_MIN, +}; + +/// Crossbreeds two neural networks of different shapes by finding cut points, and swapping neurons between the two networks. +/// Algorithm tries to ensure similar functionality is maintained between the two networks. +/// It does this by preserving connections between the same neurons from the original to the new network, and if a connection cannot be found +/// it will create a new connection with a random weight. 
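+/// Cut points are sampled from the hidden-layer id range shared by both parents
+/// (`generate_segments`), the two neuron lists are spliced segment by segment
+/// (`crossbreed_neuron_arrays`), and surviving connections are copied over by
+/// `consolidate_old_connections`; the tests at the bottom of this file walk through concrete
+/// examples using the shapes [4, 8, 6, 4] and [4, 3, 3, 3, 3, 3, 4].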
+pub fn crossbreed( + fighter_nn: &FighterNN, + primary: &Fann, + secondary: &Fann, + crossbreed_segments: usize, +) -> Result { + // First we need to get the shape of the networks and transform this into a format that is easier to work with + // We want a list of every neuron id, and the layer it is in + let primary_shape = primary.get_layer_sizes(); + let secondary_shape = secondary.get_layer_sizes(); + let primary_neurons = generate_neuron_datastructure(&primary_shape); + let secondary_neurons = generate_neuron_datastructure(&secondary_shape); + + let segments = generate_segments(primary_shape, secondary_shape, crossbreed_segments); + + let new_neurons = crossbreed_neuron_arrays(segments, primary_neurons, secondary_neurons); + + // Now we need to create the new network with the shape we've determined + let mut new_shape = vec![]; + for (_, _, layer, _) in new_neurons.iter() { + // Check if new_shape has an entry for layer in it + if new_shape.len() <= *layer { + new_shape.push(1); + } else { + new_shape[*layer] += 1; + } + } + + let mut new_fann = + Fann::new(new_shape.as_slice()).with_context(|| "Failed to create new fann")?; + // We need to randomize the weights to a small value + new_fann.randomize_weights( + fighter_nn.weight_initialization_range.start, + fighter_nn.weight_initialization_range.end, + ); + new_fann.set_activation_func_hidden(ActivationFunc::SigmoidSymmetric); + new_fann.set_activation_func_output(ActivationFunc::SigmoidSymmetric); + + consolidate_old_connections(primary, secondary, new_shape, new_neurons, &mut new_fann); + + Ok(new_fann) +} + +pub fn generate_segments( + primary_shape: Vec, + secondary_shape: Vec, + crossbreed_segments: usize, +) -> Vec<(u32, u32)> { + // Now we need to find the cut points for the crossbreed + let start = primary_shape[0] + 1; + // Start at the first hidden layer + let end = min( + primary_shape.iter().sum::() - primary_shape.last().unwrap(), + secondary_shape.iter().sum::() - secondary_shape.last().unwrap(), + ); + // End at the last hidden layer + let segment_distribution = Uniform::from(start..end); + // Ensure segments are not too small + + let mut cut_points = Vec::new(); + for _ in 0..crossbreed_segments { + let cut_point = segment_distribution.sample(&mut thread_rng()); + if !cut_points.contains(&cut_point) { + cut_points.push(cut_point); + } + } + // Sort the cut points to make it easier to iterate over them + cut_points.sort_unstable(); + + // We need to transform the cut_points vector to a vector of tuples that contain the start and end of each segment + let mut segments = Vec::new(); + let mut previous = 0; + for &cut_point in cut_points.iter() { + segments.push((previous, cut_point - 1)); + previous = cut_point; + } + segments +} + +pub fn consolidate_old_connections( + primary: &Fann, + secondary: &Fann, + new_shape: Vec, + new_neurons: Vec<(u32, bool, usize, u32)>, + new_fann: &mut Fann, +) { + // Now we need to copy the connections from the original networks to the new network + // We can do this by referencing our connections array, it will contain the original id's of the neurons + // and their new id as well as their layer. 
We can iterate one layer at a time and copy the connections + + let primary_shape = primary.get_layer_sizes(); + let secondary_shape = secondary.get_layer_sizes(); + trace!("Primary shape: {:?}", primary_shape); + trace!("Secondary shape: {:?}", secondary_shape); + trace!("New shape: {:?}", new_shape); + + // Start by iterating layer by later + let primary_connections = primary.get_connections(); + let secondary_connections = secondary.get_connections(); + for layer in 1..new_shape.len() { + // filter out the connections that are in the current layer and previous layer + let current_layer_connections = new_neurons + .iter() + .filter(|(_, _, l, _)| l == &layer) + .collect::>(); + let previous_layer_connections = new_neurons + .iter() + .filter(|(_, _, l, _)| l == &(layer - 1)) + .collect::>(); + + // Now we need to iterate over the connections in the current layer + for (neuron_id, is_primary, _, new_id) in current_layer_connections.iter() { + // We need to find the connections from the previous layer to this neuron + for (previous_neuron_id, _, _, previous_new_id) in previous_layer_connections.iter() { + // First we use primary to and check the correct connections array to see if the connection exists + // If it does, we add it to the new network + let mut connection; + let mut found_in_primary = false; + if *is_primary { + connection = primary_connections.iter().find(|connection| { + let from_neuron = + to_non_bias_network_id(connection.from_neuron, &primary_shape); + let to_neuron = + to_non_bias_network_id(connection.to_neuron, &primary_shape); + + // If both neurons have a Some value + if let (Some(from_neuron), Some(to_neuron)) = (from_neuron, to_neuron) { + from_neuron == *previous_neuron_id && to_neuron == *neuron_id + } else { + false + } + }); + + if connection.is_none() { + connection = secondary_connections.iter().find(|connection| { + let from_neuron = + to_non_bias_network_id(connection.from_neuron, &secondary_shape); + let to_neuron = + to_non_bias_network_id(connection.to_neuron, &secondary_shape); + + // If both neurons have a Some value + if let (Some(from_neuron), Some(to_neuron)) = (from_neuron, to_neuron) { + from_neuron == *previous_neuron_id && to_neuron == *neuron_id + } else { + false + } + }); + } else { + found_in_primary = true; + } + } else { + connection = secondary_connections.iter().find(|connection| { + let from_neuron = + to_non_bias_network_id(connection.from_neuron, &secondary_shape); + let to_neuron = + to_non_bias_network_id(connection.to_neuron, &secondary_shape); + + // If both neurons have a Some value + if let (Some(from_neuron), Some(to_neuron)) = (from_neuron, to_neuron) { + from_neuron == *previous_neuron_id && to_neuron == *neuron_id + } else { + false + } + }); + + if connection.is_none() { + connection = primary_connections.iter().find(|connection| { + let from_neuron = + to_non_bias_network_id(connection.from_neuron, &primary_shape); + let to_neuron = + to_non_bias_network_id(connection.to_neuron, &primary_shape); + + // If both neurons have a Some value + if let (Some(from_neuron), Some(to_neuron)) = (from_neuron, to_neuron) { + from_neuron == *previous_neuron_id && to_neuron == *neuron_id + } else { + false + } + }); + } else { + found_in_primary = true; + } + }; + + // If the connection exists, we need to add it to the new network + if let Some(connection) = connection { + if *is_primary { + let original_from_neuron = + to_non_bias_network_id(connection.from_neuron, &primary_shape); + let original_to_neuron = + 
to_non_bias_network_id(connection.to_neuron, &primary_shape); + trace!("Primary: Adding connection from ({} -> {}) translated to ({:?} -> {:?}) with weight {} for primary:{} [{} -> {}] [{} -> {}]", previous_new_id, new_id, original_from_neuron, original_to_neuron, connection.weight, found_in_primary, connection.from_neuron, connection.to_neuron, previous_neuron_id, neuron_id); + } else { + let original_from_neuron = + to_non_bias_network_id(connection.from_neuron, &secondary_shape); + let original_to_neuron = + to_non_bias_network_id(connection.to_neuron, &secondary_shape); + trace!("Secondary: Adding connection from ({} -> {}) translated to ({:?} -> {:?}) with weight {} for primary:{} [{} -> {}] [{} -> {}]", previous_new_id, new_id, original_from_neuron, original_to_neuron, connection.weight, found_in_primary, connection.from_neuron, connection.to_neuron, previous_neuron_id, neuron_id); + } + let translated_from = to_bias_network_id(previous_new_id, &new_shape); + let translated_to = to_bias_network_id(new_id, &new_shape); + new_fann.set_weight(translated_from, translated_to, connection.weight); + } else { + trace!( + "Connection not found for ({}, {}) -> ({}, {})", + previous_new_id, + new_id, + previous_neuron_id, + neuron_id + ); + } + } + } + + // Add bias neuron connections + let bias_neuron = get_bias_neuron_for_layer(layer, &new_shape); + if let Some(bias_neuron) = bias_neuron { + // Loop through neurons in current layer + for (neuron_id, is_primary, _, new_id) in current_layer_connections.iter() { + let translated_neuron_id = to_bias_network_id(new_id, &new_shape); + + let mut connection = None; + let mut found_in_primary = false; + if *is_primary { + let primary_bias_neuron = get_bias_neuron_for_layer(layer, &primary_shape); + if let Some(primary_bias_neuron) = primary_bias_neuron { + connection = primary_connections.iter().find(|connection| { + let to_neuron = + to_non_bias_network_id(connection.to_neuron, &primary_shape); + + if let Some(to_neuron) = to_neuron { + connection.from_neuron == primary_bias_neuron + && to_neuron == *neuron_id + } else { + false + } + }); + } + + if connection.is_none() { + let secondary_bias_neuron = + get_bias_neuron_for_layer(layer, &secondary_shape); + if let Some(secondary_bias_neuron) = secondary_bias_neuron { + connection = secondary_connections.iter().find(|connection| { + let to_neuron = + to_non_bias_network_id(connection.to_neuron, &secondary_shape); + + if let Some(to_neuron) = to_neuron { + connection.from_neuron == secondary_bias_neuron + && to_neuron == *neuron_id + } else { + false + } + }); + } + } else { + found_in_primary = true; + } + } else { + let secondary_bias_neuron = get_bias_neuron_for_layer(layer, &secondary_shape); + if let Some(secondary_bias_neuron) = secondary_bias_neuron { + connection = secondary_connections.iter().find(|connection| { + let to_neuron = + to_non_bias_network_id(connection.to_neuron, &secondary_shape); + + if let Some(to_neuron) = to_neuron { + connection.from_neuron == secondary_bias_neuron + && to_neuron == *neuron_id + } else { + false + } + }); + } + + if connection.is_none() { + let primary_bias_neuron = get_bias_neuron_for_layer(layer, &primary_shape); + if let Some(primary_bias_neuron) = primary_bias_neuron { + connection = primary_connections.iter().find(|connection| { + let to_neuron = + to_non_bias_network_id(connection.to_neuron, &primary_shape); + + if let Some(to_neuron) = to_neuron { + connection.from_neuron == primary_bias_neuron + && to_neuron == *neuron_id + } else { + false + } + }); 
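+                                // If neither parent has a bias connection into this neuron, the
+                                // randomly initialised weight from `randomize_weights` is kept and
+                                // the miss is logged below as "Connection not found".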
+ } + } else { + found_in_primary = true; + } + } + + if let Some(connection) = connection { + if *is_primary { + let original_from_neuron = + to_non_bias_network_id(connection.from_neuron, &primary_shape); + let original_to_neuron = + to_non_bias_network_id(connection.to_neuron, &primary_shape); + trace!("Primary: Adding connection from ({} -> {}) translated to ({:?} -> {:?}) with weight {} for primary:{} [{} -> {}] [{} -> {}]", bias_neuron, translated_neuron_id, original_from_neuron, original_to_neuron, connection.weight, found_in_primary, connection.from_neuron, connection.to_neuron, bias_neuron, neuron_id); + } else { + let original_from_neuron = + to_non_bias_network_id(connection.from_neuron, &secondary_shape); + let original_to_neuron = + to_non_bias_network_id(connection.to_neuron, &secondary_shape); + trace!("Secondary: Adding connection from ({} -> {}) translated to ({:?} -> {:?}) with weight {} for primary:{} [{} -> {}] [{} -> {}]", bias_neuron, translated_neuron_id, original_from_neuron, original_to_neuron, connection.weight, found_in_primary, connection.from_neuron, connection.to_neuron, bias_neuron, neuron_id); + } + new_fann.set_weight(bias_neuron, translated_neuron_id, connection.weight); + } else { + trace!( + "Connection not found for bias ({}, {}) -> ({}, {}) primary: {}", + bias_neuron, + neuron_id, + bias_neuron, + translated_neuron_id, + is_primary + ); + } + } + } + } +} + +pub fn crossbreed_neuron_arrays( + segments: Vec<(u32, u32)>, + primary_neurons: Vec<(u32, usize)>, + secondary_neurons: Vec<(u32, usize)>, +) -> Vec<(u32, bool, usize, u32)> { + // We now need to determine the resulting location of the neurons in the new network. + // To do this we need a new structure that keeps track of the following information: + // - The neuron id from the original network + // - Which network it originated from (primary or secondary) + // - The layer the neuron is in + // - The resulting neuron id in the new network which will be calculated after the fact + let mut new_neurons = Vec::new(); + let mut current_layer = 0; + // keep track of the last layer that we inserted a neuron into for each network + let mut primary_last_layer = 0; + let mut secondary_last_layer = 0; + let mut is_primary = true; + for (i, &segment) in segments.iter().enumerate() { + // If it's the first slice, copy neurons from the primary network up to the cut_point + if i == 0 { + for (neuron_id, layer) in primary_neurons.iter() { + if neuron_id <= &segment.1 { + if layer > ¤t_layer { + current_layer += 1; + } + new_neurons.push((*neuron_id, is_primary, current_layer, 0)); + if is_primary { + primary_last_layer = current_layer; + } else { + secondary_last_layer = current_layer; + } + } else { + break; + } + } + } else { + let target_neurons = if is_primary { + &primary_neurons + } else { + &secondary_neurons + }; + + for (neuron_id, layer) in target_neurons.iter() { + // Iterate until neuron_id equals the cut_point + if neuron_id >= &segment.0 && neuron_id <= &segment.1 { + // We need to do something different depending on whether the neuron layer is, lower, higher or equal to the target layer + + // Equal + if layer == ¤t_layer { + new_neurons.push((*neuron_id, is_primary, current_layer, 0)); + + if is_primary { + primary_last_layer = current_layer; + } else { + secondary_last_layer = current_layer; + } + } + // Earlier + else if layer < ¤t_layer { + // If it's in an earlier layer, add it to the earlier layer + // Check if there's a lower id from the same individual in that earlier layer + // As long as 
there isn't a neuron from the other individual in between the lower id and current id, add the id values from the same individual + let earlier_layer_neurons = new_neurons + .iter() + .filter(|(_, _, l, _)| l == layer) + .collect::>(); + // get max id from that layer + let highest_id = earlier_layer_neurons + .iter() + .max_by(|a, b| a.2.cmp(&b.2).then(a.0.cmp(&b.0))); + if let Some(highest_id) = highest_id { + if highest_id.1 == is_primary { + let neurons_to_add = target_neurons + .iter() + .filter(|(id, l)| { + id > &highest_id.0 && id < neuron_id && l == layer + }) + .collect::>(); + for (neuron_id, layer) in neurons_to_add { + new_neurons.push((*neuron_id, is_primary, *layer, 0)); + + if is_primary { + primary_last_layer = *layer; + } else { + secondary_last_layer = *layer; + } + } + } + } + + new_neurons.push((*neuron_id, is_primary, *layer, 0)); + + if is_primary { + primary_last_layer = *layer; + } else { + secondary_last_layer = *layer; + } + } + // Later + else if layer > ¤t_layer { + // If the highest id in the current layer is from the same individual, add anything with a higher id to the current layer before moving to the next layer + // First filter new_neurons to look at neurons from the current layer + let current_layer_neurons = new_neurons + .iter() + .filter(|(_, _, l, _)| l == ¤t_layer) + .collect::>(); + let highest_id = + current_layer_neurons.iter().max_by_key(|(id, _, _, _)| id); + if let Some(highest_id) = highest_id { + if highest_id.1 == is_primary { + let neurons_to_add = target_neurons + .iter() + .filter(|(id, l)| id > &highest_id.0 && *l == layer - 1) + .collect::>(); + for (neuron_id, _) in neurons_to_add { + new_neurons.push((*neuron_id, is_primary, current_layer, 0)); + + if is_primary { + primary_last_layer = current_layer; + } else { + secondary_last_layer = current_layer; + } + } + } + } + + // If it's in a future layer, move to the next layer + current_layer += 1; + + // Add the neuron to the new network + // Along with any neurons that have a lower id in the future layer + let neurons_to_add = target_neurons + .iter() + .filter(|(id, l)| id <= neuron_id && l == layer) + .collect::>(); + for (neuron_id, _) in neurons_to_add { + new_neurons.push((*neuron_id, is_primary, current_layer, 0)); + + if is_primary { + primary_last_layer = current_layer; + } else { + secondary_last_layer = current_layer; + } + } + } + } else if neuron_id >= &segment.1 { + break; + } + } + } + + // Switch to the other network + is_primary = !is_primary; + } + + // For the last segment, copy the remaining neurons + let target_neurons = if is_primary { + &primary_neurons + } else { + &secondary_neurons + }; + // Get output layer number + let output_layer = target_neurons.iter().max_by_key(|(_, l)| l).unwrap().1; + + // For the last segment, copy the remaining neurons from the target network + // But when we reach the output layer, we need to add a new layer to the end of new_neurons regardless of it's length + // and copy the output neurons to that layer + for (neuron_id, layer) in target_neurons.iter() { + if neuron_id > &segments.last().unwrap().1 { + if layer == &output_layer { + // Calculate which layer the neurons should be in + current_layer = new_neurons.iter().max_by_key(|(_, _, l, _)| l).unwrap().2 + 1; + for (neuron_id, _) in target_neurons.iter().filter(|(_, l)| l == &output_layer) { + new_neurons.push((*neuron_id, is_primary, current_layer, 0)); + } + break; + } else if *neuron_id == &segments.last().unwrap().1 + 1 { + let target_layer = if is_primary { + 
primary_last_layer + } else { + secondary_last_layer + }; + let earlier_layer_neurons = new_neurons + .iter() + .filter(|(_, _, l, _)| *l >= target_layer && l <= layer) + .collect::>(); + // get max neuron from with both + // The highest layer + // get max id from that layer + let highest_id = earlier_layer_neurons + .iter() + .max_by(|a, b| a.2.cmp(&b.2).then(a.0.cmp(&b.0))); + if let Some(highest_id) = highest_id { + if highest_id.1 == is_primary { + let neurons_to_add = target_neurons + .iter() + .filter(|(id, _)| id > &highest_id.0 && id < neuron_id) + .collect::>(); + for (neuron_id, l) in neurons_to_add { + new_neurons.push((*neuron_id, is_primary, *l, 0)); + } + } + } + + new_neurons.push((*neuron_id, is_primary, *layer, 0)); + } else { + new_neurons.push((*neuron_id, is_primary, *layer, 0)); + } + } + } + + // Filtering layers with too few neurons, if necessary + let layer_counts = new_neurons.iter().fold( + vec![0; current_layer + 1], + |mut counts, &(_, _, layer, _)| { + counts[layer] += 1; + counts + }, + ); + + // Filter out layers based on the minimum number of neurons per layer + new_neurons = new_neurons + .into_iter() + .filter(|&(_, _, layer, _)| layer_counts[layer] >= NEURAL_NETWORK_HIDDEN_LAYER_SIZE_MIN) + .collect::>(); + + // If a layer has more than NEURAL_NETWORK_HIDDEN_LAYER_SIZE_MAX, remove the neurons with the highest id + for layer in 1..layer_counts.len() - 1 { + let new_neurons_clone = new_neurons.clone(); + let layer_neurons = new_neurons_clone + .iter() + .filter(|(_, _, l, _)| l == &layer) + .collect::>(); + if layer_neurons.len() > NEURAL_NETWORK_HIDDEN_LAYER_SIZE_MAX { + let mut sorted_neurons = layer_neurons.clone(); + // Take primary neurons first, order by highest id + sorted_neurons.sort_by(|a, b| a.1.cmp(&b.1).then(a.0.cmp(&b.0))); + let neurons_to_remove = sorted_neurons.len() - NEURAL_NETWORK_HIDDEN_LAYER_SIZE_MAX; + for _ in 0..neurons_to_remove { + let neuron_to_remove = sorted_neurons.pop().unwrap(); + new_neurons.retain(|neuron| neuron != neuron_to_remove); + } + } + } + + // Collect and sort unique layer numbers + let mut unique_layers = new_neurons + .iter() + .map(|(_, _, layer, _)| *layer) + .collect::>(); + unique_layers.sort(); + unique_layers.dedup(); // Removes duplicates, keeping only unique layer numbers + + // Create a mapping from old layer numbers to new (gap-less) layer numbers + let layer_mapping = unique_layers + .iter() + .enumerate() + .map(|(new_layer, &old_layer)| (old_layer, new_layer)) + .collect::>(); + + // Apply the mapping to renumber layers in new_neurons + new_neurons.iter_mut().for_each(|(_, _, layer, _)| { + *layer = *layer_mapping.get(layer).unwrap_or(layer); // Fallback to original layer if not found, though it should always find a match + }); + + // Assign new IDs + // new_neurons must be sorted by layer, then by neuron ID within the layer + new_neurons.sort_unstable_by(|a, b| a.2.cmp(&b.2).then(a.0.cmp(&b.0))); + new_neurons + .iter_mut() + .enumerate() + .for_each(|(new_id, neuron)| { + neuron.3 = new_id as u32; + }); + + new_neurons +} + +pub fn major_mutation(fann: &Fann, weight_initialization_range: Range) -> Result { + // add or remove a random neuron from a hidden layer + let mut mutated_shape = fann.get_layer_sizes().to_vec(); + let mut mutated_neurons = generate_neuron_datastructure(&mutated_shape) + .iter() + .map(|(id, layer)| (*id, true, *layer, *id)) + .collect::>(); + + // Determine first whether to add or remove a neuron + if thread_rng().gen_bool(0.5) { + // To add a neuron we need to create a new 
fann object with the new layer sizes, then copy the information and connections over + let max_id = mutated_neurons + .iter() + .max_by_key(|(id, _, _, _)| id) + .unwrap() + .0; + + // Now we inject the new neuron into mutated_neurons + let layer = thread_rng().gen_range(1..fann.get_num_layers() - 1) as usize; + // Do not add to layer if it would result in more than NEURALNETWORK_HIDDEN_LAYER_SIZE_MAX neurons + if mutated_shape[layer] < NEURAL_NETWORK_HIDDEN_LAYER_SIZE_MAX as u32 { + let new_id = max_id + 1; + mutated_neurons.push((new_id, true, layer, new_id)); + mutated_shape[layer] += 1; + } + } else { + // Remove a neuron + let layer = thread_rng().gen_range(1..fann.get_num_layers() - 1) as usize; + // Do not remove from layer if it would result in less than NEURALNETWORK_HIDDEN_LAYER_SIZE_MIN neurons + if mutated_shape[layer] > NEURAL_NETWORK_HIDDEN_LAYER_SIZE_MIN as u32 { + let remove_id = mutated_neurons + .iter() + .filter(|(_, _, l, _)| l == &layer) + .choose(&mut thread_rng()) + .unwrap() + .0; + mutated_neurons.retain(|(id, _, _, _)| id != &remove_id); + mutated_shape[layer] -= 1; + } + } + + let mut mutated_fann = + Fann::new(mutated_shape.as_slice()).with_context(|| "Failed to create new fann")?; + mutated_fann.randomize_weights( + weight_initialization_range.start, + weight_initialization_range.end, + ); + mutated_fann.set_activation_func_hidden(ActivationFunc::SigmoidSymmetric); + mutated_fann.set_activation_func_output(ActivationFunc::SigmoidSymmetric); + + // We need to regenerate the new_id's in mutated_neurons (the 4th item in the tuple) we can do this by iterating over the mutated_neurons all over again starting from ZERO + mutated_neurons.sort_by(|a, b| a.2.cmp(&b.2).then(a.0.cmp(&b.0))); + for (i, (_, _, _, new_id)) in mutated_neurons.iter_mut().enumerate() { + *new_id = i as u32; + } + + // We need to copy the connections from the old fann to the new fann + consolidate_old_connections( + fann, + fann, + mutated_shape, + mutated_neurons, + &mut mutated_fann, + ); + + Ok(mutated_fann) +} + +pub fn generate_neuron_datastructure(shape: &[u32]) -> Vec<(u32, usize)> { + let mut result = Vec::new(); + let mut global_index = 0; // Keep a global index that does not reset + + for (layer_index, &neurons) in shape.iter().enumerate() { + for _ in 0..neurons { + result.push((global_index, layer_index)); + global_index += 1; // Increment global index for each neuron + } + // global_index += 1; // Skip index for bias neuron at the end of each layer + } + + result +} + +fn to_bias_network_id(id: &u32, shape: &[u32]) -> u32 { + // The given id comes from a network without a bias neuron at the end of every layer + // We need to translate this id to the id in the network with bias neurons + let mut translated_id = 0; + for (layer_index, &neurons) in shape.iter().enumerate() { + for _ in 0..neurons { + if &translated_id == id { + return translated_id + layer_index as u32; + } + translated_id += 1; + } + } + + // If the id is not found, return the id + translated_id +} + +fn to_non_bias_network_id(id: u32, shape: &[u32]) -> Option { + let mut total_neurons = 0; // Total count of neurons (excluding bias neurons) processed + + for (bias_count, &neurons) in shape.iter().enumerate() { + let layer_end = total_neurons + neurons; // End of the current layer, excluding the bias neuron + if id < layer_end { + // ID is within the current layer (excluding the bias neuron) + return Some(id - bias_count as u32); + } + if id == layer_end { + // ID matches the position where a bias neuron would be + 
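+            // e.g. with shape [4, 8, 6, 4] the bias positions are ids 4, 13 and 20
+            // (cf. `get_bias_neuron_for_layer_test`), so those ids map to None here.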
return None; + } + + // Update counts after considering the current layer + total_neurons += neurons + 1; // Move to the next layer, accounting for the bias neuron + } + + // If the ID is beyond the range of all neurons (including bias), it's treated as invalid + // Adjust this behavior based on your application's needs + None +} + +fn get_bias_neuron_for_layer(layer: usize, shape: &[u32]) -> Option { + if layer == 0 || layer >= shape.len() { + // No bias neuron for the first and last layers + None + } else { + // Compute the bias neuron for intermediate layers + let mut bias = 0; + for layer_count in shape.iter().take(layer) { + bias += layer_count; + } + Some(bias + layer as u32 - 1) + } +} + +#[cfg(test)] +mod tests { + use std::collections::HashSet; + + use super::*; + + #[test] + fn major_mutation_test() -> Result<(), Box> { + // Assign + let primary_shape = vec![2, 8, 5, 3, 1]; + // [2, 11, 17, 21] + // [0, 1, 2, 3] + + let mut primary_fann = Fann::new(&primary_shape)?; + + let mut primary_connections = primary_fann.get_connections(); + for connection in primary_connections.iter_mut() { + connection.weight = ((connection.from_neuron * 100) + connection.to_neuron) as f32; + } + primary_fann.set_connections(&primary_connections); + + let weight_initialization_range = -1.0..-0.5; + + for _ in 0..100 { + let result = major_mutation(&primary_fann, weight_initialization_range.clone())?; + + let connections = result.get_connections(); + for connection in connections.iter() { + println!("Connection: {:?}", connection); + } + + let new_shape = result.get_layer_sizes(); + println!("New Shape: {:?}", new_shape); + + // Assert that input and output layers have the same size + assert_eq!(primary_shape[0], new_shape[0]); + assert_eq!( + primary_shape[primary_shape.len() - 1], + new_shape[new_shape.len() - 1] + ); + + // Determine if a neuron was removed or added + if new_shape.iter().sum::() == primary_shape.iter().sum::() + 1 { + //Neuron was added + // Find id of neuron that was added + let mut added_neuron_id = 0; + let matching_layers = new_shape.iter().zip(primary_shape.iter()); + for (i, (new_layer, primary_layer)) in matching_layers.enumerate() { + if new_layer > primary_layer { + added_neuron_id += primary_layer + i as u32; + break; + } + added_neuron_id += primary_layer; + } + + for connection in connections.iter() { + if connection.from_neuron == added_neuron_id + || connection.to_neuron == added_neuron_id + { + assert!( + connection.weight < 0.0, + "Connection: {:?}, Added Neuron: {}", + connection, + added_neuron_id + ); + } else { + assert!( + connection.weight > 0.0, + "Connection: {:?}, Added Neuron: {}", + connection, + added_neuron_id + ); + } + } + } else if new_shape.iter().sum::() == primary_shape.iter().sum::() - 1 { + //Neuron was removed + for connection in connections.iter() { + assert!(connection.weight > 0.0, "Connection: {:?}", connection); + } + + for (i, layer) in new_shape.iter().enumerate() { + // if layer isn't input or output + if i != 0 && i as u32 != new_shape.len() as u32 - 1 { + assert!( + *layer >= NEURAL_NETWORK_HIDDEN_LAYER_SIZE_MIN as u32, + "Layer: {}", + layer + ); + } + } + } else { + //Neuron was neither added nor removed + for connection in connections.iter() { + assert!(connection.weight > 0.0, "Connection: {:?}", connection); + } + } + } + + Ok(()) + } + + #[test] + fn generate_segments_test() { + // Assign + let primary_shape = vec![4, 8, 6, 4]; + let secondary_shape = vec![4, 3, 3, 3, 3, 3, 4]; + let crossbreed_segments = 5; + + // Act + let result 
= generate_segments( + primary_shape.clone(), + secondary_shape.clone(), + crossbreed_segments, + ); + + println!("{:?}", result); + + // Assert + assert!( + result.len() <= crossbreed_segments, + "Segments: {:?}", + result + ); + //Assert that segments are within the bounds of the layers + for (start, end) in result.iter() { + // Bounds are the end of the first layer to the end of the second to last layer + let bounds = 3..17; + + assert!(bounds.contains(end)); + assert!(start < &bounds.end); + } + + //Assert that segments start and end are in ascending order + for (start, end) in result.iter() { + assert!(*start <= *end, "Start: {}, End: {}", start, end); + } + + // Test that segments are contiguous + for i in 0..result.len() - 1 { + assert_eq!(result[i].1 + 1, result[i + 1].0); + } + + // Testing with more segments than possible + let crossbreed_segments = 15; + + // Act + let result = generate_segments( + primary_shape.clone(), + secondary_shape.clone(), + crossbreed_segments, + ); + + println!("{:?}", result); + + //Assert that segments are within the bounds of the layers + for (start, end) in result.iter() { + // Bounds are the end of the first layer to the end of the second to last layer + let bounds = 3..17; + + assert!(bounds.contains(end)); + assert!(start < &bounds.end); + } + + //Assert that segments start and end are in ascending order + for (start, end) in result.iter() { + assert!(*start <= *end, "Start: {}, End: {}", start, end); + } + + // Test that segments are contiguous + for i in 0..result.len() - 1 { + assert_eq!(result[i].1 + 1, result[i + 1].0); + } + } + + #[test] + fn get_bias_neuron_for_layer_test() { + // Assign + let shape = vec![4, 8, 6, 4]; + + // Act + let result = get_bias_neuron_for_layer(0, &shape); + + // Assert + assert_eq!(result, None); + + // Act + let result = get_bias_neuron_for_layer(1, &shape); + + // Assert + assert_eq!(result, Some(4)); + + // Act + let result = get_bias_neuron_for_layer(2, &shape); + + // Assert + assert_eq!(result, Some(13)); + + // Act + let result = get_bias_neuron_for_layer(3, &shape); + + // Assert + assert_eq!(result, Some(20)); + + // Act + let result = get_bias_neuron_for_layer(4, &shape); + + // Assert + assert_eq!(result, None); + } + + #[test] + fn crossbreed_neuron_arrays_test() { + // Assign + let segments = vec![(0, 3), (4, 6), (7, 8), (9, 10)]; + + let primary_network = generate_neuron_datastructure(&vec![4, 8, 6, 4]); + + let secondary_network = generate_neuron_datastructure(&vec![4, 3, 3, 3, 3, 3, 4]); + + // Act + let result = crossbreed_neuron_arrays( + segments.clone(), + primary_network.clone(), + secondary_network.clone(), + ); + + // Expected Result Set + let expected: HashSet<(u32, bool, usize, u32)> = vec![ + // Input layer: Expect 4 + (0, true, 0, 0), + (1, true, 0, 1), + (2, true, 0, 2), + (3, true, 0, 3), + // Hidden Layer 1: Expect 8 + (4, false, 1, 4), + (5, false, 1, 5), + (6, false, 1, 6), + (7, true, 1, 7), + (8, true, 1, 8), + (9, true, 1, 9), + (10, true, 1, 10), + (11, true, 1, 11), + // Hidden Layer 2: Expect 9 + (7, false, 2, 12), + (8, false, 2, 13), + (9, false, 2, 14), + (12, true, 2, 15), + (13, true, 2, 16), + (14, true, 2, 17), + (15, true, 2, 18), + (16, true, 2, 19), + (17, true, 2, 20), + // Output Layer: Expect 4 + (18, true, 3, 21), + (19, true, 3, 22), + (20, true, 3, 23), + (21, true, 3, 24), + ] + .into_iter() + .collect(); + + // Convert Result to HashSet for Comparison + let result_set: HashSet<(u32, bool, usize, u32)> = result.into_iter().collect(); + + // Assert + 
assert_eq!(result_set, expected); + + // Now we test the ooposite case + // Act + let result = crossbreed_neuron_arrays( + segments.clone(), + secondary_network.clone(), + primary_network.clone(), + ); + + // Expected Result Set + let expected: HashSet<(u32, bool, usize, u32)> = vec![ + // Input layer: Expect 4 + (0, true, 0, 0), + (1, true, 0, 1), + (2, true, 0, 2), + (3, true, 0, 3), + // Hidden Layer 1: Expect 7 + (4, false, 1, 4), + (5, false, 1, 5), + (6, false, 1, 6), + (7, false, 1, 7), + (8, false, 1, 8), + (9, false, 1, 9), + (10, false, 1, 10), + // Hidden Layer 2: Expect 3 + (7, true, 2, 11), + (8, true, 2, 12), + (9, true, 2, 13), + // Hidden Layer 3: Expect 3 + (10, true, 3, 14), + (11, true, 3, 15), + (12, true, 3, 16), + // Hidden Layer 4: Expect 3 + (13, true, 4, 17), + (14, true, 4, 18), + (15, true, 4, 19), + // Hidden Layer 5: Expect 3 + (16, true, 5, 20), + (17, true, 5, 21), + (18, true, 5, 22), + // Output Layer: Expect 4 + (19, true, 6, 23), + (20, true, 6, 24), + (21, true, 6, 25), + (22, true, 6, 26), + ] + .into_iter() + .collect(); + + // Convert Result to HashSet for Comparison + let result_set: HashSet<(u32, bool, usize, u32)> = result.into_iter().collect(); + + // Assert + assert_eq!(result_set, expected); + + // Testing with a different segment + // Assign + let segments = vec![(0, 4), (5, 14), (15, 15), (16, 16)]; + + // Act + let result = crossbreed_neuron_arrays( + segments.clone(), + primary_network.clone(), + secondary_network.clone(), + ); + + // Expected Result Set + let expected: HashSet<(u32, bool, usize, u32)> = vec![ + // Input layer: Expect 4 + (0, true, 0, 0), + (1, true, 0, 1), + (2, true, 0, 2), + (3, true, 0, 3), + // Hidden Layer 1: Expect 3 + (4, true, 1, 4), + (5, false, 1, 5), + (6, false, 1, 6), + // Hidden Layer 2: Expect 6 + (7, false, 2, 7), + (8, false, 2, 8), + (9, false, 2, 9), + (15, true, 2, 10), + (16, true, 2, 11), + (17, true, 2, 12), + // Hidden Layer 3: Expect 3 + (10, false, 3, 13), + (11, false, 3, 14), + (12, false, 3, 15), + // Hidden Layer 4: Expect 3 + (13, false, 4, 16), + (14, false, 4, 17), + (15, false, 4, 18), + // Output Layer: Expect 4 + (18, true, 5, 19), + (19, true, 5, 20), + (20, true, 5, 21), + (21, true, 5, 22), + ] + .into_iter() + .collect(); + + // print result before comparison + for r in result.iter() { + println!("{:?}", r); + } + + // Convert Result to HashSet for Comparison + let result_set: HashSet<(u32, bool, usize, u32)> = result.into_iter().collect(); + + // Assert + assert_eq!(result_set, expected); + + // Swapping order + let result = crossbreed_neuron_arrays( + segments.clone(), + secondary_network.clone(), + primary_network.clone(), + ); + + // Expected Result Set + let expected: HashSet<(u32, bool, usize, u32)> = vec![ + // Input layer: Expect 4 + (0, true, 0, 0), + (1, true, 0, 1), + (2, true, 0, 2), + (3, true, 0, 3), + // Hidden Layer 1: Expect 8 + (4, true, 1, 4), + (5, false, 1, 5), + (6, false, 1, 6), + (7, false, 1, 7), + (8, false, 1, 8), + (9, false, 1, 9), + (10, false, 1, 10), + (11, false, 1, 11), + // Hidden Layer 2: Expect 5 + (12, false, 2, 12), + (13, false, 2, 13), + (14, false, 2, 14), + (15, false, 2, 15), + (16, false, 2, 16), + // Hidden Layer 3: Expect 3 + (13, true, 3, 17), + (14, true, 3, 18), + (15, true, 3, 19), + // Hidden Layer 4: Expect 3 + (16, true, 4, 20), + (17, true, 4, 21), + (18, true, 4, 22), + // Output Layer: Expect 4 + (19, true, 5, 23), + (20, true, 5, 24), + (21, true, 5, 25), + (22, true, 5, 26), + ] + .into_iter() + .collect(); + + // print result 
before comparison + for r in result.iter() { + println!("{:?}", r); + } + + // Convert Result to HashSet for Comparison + let result_set: HashSet<(u32, bool, usize, u32)> = result.into_iter().collect(); + + // Assert + assert_eq!(result_set, expected); + + // Testing with a different segment + // Assign + let segments = vec![(0, 7), (8, 9), (10, 10), (11, 12)]; + + // Act + let result = crossbreed_neuron_arrays( + segments.clone(), + primary_network.clone(), + secondary_network.clone(), + ); + + // Expected Result Set + let expected: HashSet<(u32, bool, usize, u32)> = vec![ + // Input layer: Expect 4 + (0, true, 0, 0), + (1, true, 0, 1), + (2, true, 0, 2), + (3, true, 0, 3), + // Hidden Layer 1: Expect 7 + (4, true, 1, 4), + (5, true, 1, 5), + (6, true, 1, 6), + (7, true, 1, 7), + (8, true, 1, 8), + (9, true, 1, 9), + (10, true, 1, 10), + // Hidden Layer 2: Expect 8 + (7, false, 2, 11), + (8, false, 2, 12), + (9, false, 2, 13), + (13, true, 2, 14), + (14, true, 2, 15), + (15, true, 2, 16), + (16, true, 2, 17), + (17, true, 2, 18), + // Hidden Layer 3: Expect 3 + (10, false, 3, 19), + (11, false, 3, 20), + (12, false, 3, 21), + // Output Layer: Expect 4 + (18, true, 4, 22), + (19, true, 4, 23), + (20, true, 4, 24), + (21, true, 4, 25), + ] + .into_iter() + .collect(); + + // print result before comparison + for r in result.iter() { + println!("{:?}", r); + } + + // Convert Result to HashSet for Comparison + let result_set: HashSet<(u32, bool, usize, u32)> = result.into_iter().collect(); + + // Assert + assert_eq!(result_set, expected); + + // Swapping order + let result = crossbreed_neuron_arrays( + segments.clone(), + secondary_network.clone(), + primary_network.clone(), + ); + + // Expected Result Set + let expected: HashSet<(u32, bool, usize, u32)> = vec![ + // Input layer: Expect 4 + (0, true, 0, 0), + (1, true, 0, 1), + (2, true, 0, 2), + (3, true, 0, 3), + // Hidden Layer 1: Expect 7 + (4, true, 1, 4), + (5, true, 1, 5), + (6, true, 1, 6), + (8, false, 1, 7), + (9, false, 1, 8), + (10, false, 1, 9), + (11, false, 1, 10), + // Hidden Layer 2: Expect 4 + (7, true, 2, 11), + (8, true, 2, 12), + (9, true, 2, 13), + (12, false, 2, 14), + // Hidden Layer 3: Expect 3 + (10, true, 3, 15), + (11, true, 3, 16), + (12, true, 3, 17), + // Hidden Layer 4: Expect 3 + (13, true, 4, 18), + (14, true, 4, 19), + (15, true, 4, 20), + // Hidden Layer 5: Expect 3 + (16, true, 5, 21), + (17, true, 5, 22), + (18, true, 5, 23), + // Output Layer: Expect 4 + (19, true, 6, 24), + (20, true, 6, 25), + (21, true, 6, 26), + (22, true, 6, 27), + ] + .into_iter() + .collect(); + + // print result before comparison + for r in result.iter() { + println!("{:?}", r); + } + + // Convert Result to HashSet for Comparison + let result_set: HashSet<(u32, bool, usize, u32)> = result.into_iter().collect(); + + // Assert + assert_eq!(result_set, expected); + + // Testing networks with the same size + // Assign + let segments = vec![(0, 3), (4, 6), (7, 8), (9, 11)]; + + let primary_network = generate_neuron_datastructure(&vec![4, 3, 4, 5, 4]); + + vec![ + // Input layer + (0, 0), + (1, 0), + (2, 0), + (3, 0), + // Hidden layer 1: 3 neurons + (4, 1), + (5, 1), + (6, 1), + // Hidden Layer 2: 4 neurons + (7, 2), + (8, 2), + (9, 2), + (10, 2), + // Hidden Layer 3: 5 neurons + (11, 3), + (12, 3), + (13, 3), + (14, 3), + (15, 3), + // Output layer + (16, 4), + (17, 4), + (18, 4), + (19, 4), + ]; + + let secondary_network = primary_network.clone(); + + // Act + let result = crossbreed_neuron_arrays( + segments.clone(), + 
primary_network.clone(), + secondary_network.clone(), + ); + + // Expected Result Set + let expected: HashSet<(u32, bool, usize, u32)> = vec![ + // Input layer: Expect 4 + (0, true, 0, 0), + (1, true, 0, 1), + (2, true, 0, 2), + (3, true, 0, 3), + // Hidden Layer 1: Expect 3 + (4, false, 1, 4), + (5, false, 1, 5), + (6, false, 1, 6), + // Hidden Layer 2: Expect 4 + (7, true, 2, 7), + (8, true, 2, 8), + (9, false, 2, 9), + (10, false, 2, 10), + // Hidden Layer 3: Expect 5 + (11, false, 3, 11), + (12, true, 3, 12), + (13, true, 3, 13), + (14, true, 3, 14), + (15, true, 3, 15), + // Output Layer: Expect 4 + (16, true, 4, 16), + (17, true, 4, 17), + (18, true, 4, 18), + (19, true, 4, 19), + ] + .into_iter() + .collect(); + + // print result before comparison + for r in result.iter() { + println!("{:?}", r); + } + + // Convert Result to HashSet for Comparison + let result_set: HashSet<(u32, bool, usize, u32)> = result.into_iter().collect(); + + // Assert + assert_eq!(result_set, expected); + + // Testing with different segment + let segments = vec![(0, 5), (6, 6), (7, 11), (12, 13)]; + + // Act + let result = crossbreed_neuron_arrays( + segments.clone(), + primary_network.clone(), + secondary_network.clone(), + ); + + // Expected Result Set + let expected: HashSet<(u32, bool, usize, u32)> = vec![ + // Input layer: Expect 4 + (0, true, 0, 0), + (1, true, 0, 1), + (2, true, 0, 2), + (3, true, 0, 3), + // Hidden Layer 1: Expect 3 + (4, true, 1, 4), + (5, true, 1, 5), + (6, false, 1, 6), + // Hidden Layer 2: Expect 4 + (7, true, 2, 7), + (8, true, 2, 8), + (9, true, 2, 9), + (10, true, 2, 10), + // Hidden Layer 3: Expect 5 + (11, true, 3, 11), + (12, false, 3, 12), + (13, false, 3, 13), + (14, true, 3, 14), + (15, true, 3, 15), + // Output Layer: Expect 4 + (16, true, 4, 16), + (17, true, 4, 17), + (18, true, 4, 18), + (19, true, 4, 19), + ] + .into_iter() + .collect(); + + // print result before comparison + for r in result.iter() { + println!("{:?}", r); + } + + // Convert Result to HashSet for Comparison + let result_set: HashSet<(u32, bool, usize, u32)> = result.into_iter().collect(); + + // Assert + assert_eq!(result_set, expected); + } + + #[test] + fn generate_neuron_datastructure_test() { + // Assign + let shape = vec![4, 3, 5, 4]; + + // Act + let result = generate_neuron_datastructure(shape.as_slice()); + + // Expected Result + let expected: Vec<(u32, usize)> = vec![ + (0, 0), + (1, 0), + (2, 0), + (3, 0), + (4, 1), + (5, 1), + (6, 1), + (7, 2), + (8, 2), + (9, 2), + (10, 2), + (11, 2), + (12, 3), + (13, 3), + (14, 3), + (15, 3), + ]; + + // Assert + assert_eq!(result, expected); + } + + #[test] + fn translate_neuron_id_test() { + // Assign + let shape = vec![4, 3, 5, 4]; + + let expected = vec![ + // (input, expected output) + (0, 0), + (1, 1), + (2, 2), + (3, 3), + (4, 5), + (5, 6), + (6, 7), + (7, 9), + (8, 10), + (9, 11), + (10, 12), + (11, 13), + (12, 15), + (13, 16), + (14, 17), + (15, 18), + ]; + + // Act + for (input, expected_output) in expected { + let result = to_bias_network_id(&input, &shape); + // Assert + assert_eq!(result, expected_output); + + // Go the other direction too + let result = to_non_bias_network_id(expected_output, &shape); + + // Assert + if let Some(result) = result { + assert_eq!(result, input); + } else { + assert!(false, "Expected Some, got None"); + } + } + + // Validate bias neuron values + let bias_neurons = vec![4, 8, 14, 19]; + + for &bias_neuron in bias_neurons.iter() { + let result = to_non_bias_network_id(bias_neuron, &shape); + + // Assert + 
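+        // Note: for shape [4, 3, 5, 4] the ids 4, 8, 14 and 19 are the FANN bias neurons
+        // appended after each layer's regular neurons in the bias-inclusive numbering, so
+        // they have no counterpart in the non-bias numbering and the translation should be None.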
assert!(result.is_none()); + } + } + + #[test] + fn consolidate_old_connections_test() -> Result<(), Box> { + // Assign + let primary_shape = vec![4, 8, 6, 4]; + let secondary_shape = vec![4, 3, 3, 3, 3, 3, 4]; + + let mut primary_fann = Fann::new(&primary_shape)?; + let mut secondary_fann = Fann::new(&secondary_shape)?; + + let mut primary_connections = primary_fann.get_connections(); + for connection in primary_connections.iter_mut() { + connection.weight = ((connection.from_neuron * 100) + connection.to_neuron) as f32; + } + primary_fann.set_connections(&primary_connections); + + let mut secondary_connections = secondary_fann.get_connections(); + for connection in secondary_connections.iter_mut() { + connection.weight = ((connection.from_neuron * 100) + connection.to_neuron) as f32; + connection.weight = connection.weight * -1.0; + } + secondary_fann.set_connections(&secondary_connections); + + let new_neurons = vec![ + // Input layer: Expect 4 + (0, true, 0, 0), + (1, true, 0, 1), + (2, true, 0, 2), + (3, true, 0, 3), + // Hidden Layer 1: Expect 8 + (4, false, 1, 4), + (5, false, 1, 5), + (6, false, 1, 6), + (7, true, 1, 7), + (8, true, 1, 8), + (9, true, 1, 9), + (10, true, 1, 10), + (11, true, 1, 11), + // Hidden Layer 2: Expect 9 + (7, false, 2, 12), + (8, false, 2, 13), + (9, false, 2, 14), + (12, true, 2, 15), + (13, true, 2, 16), + (14, true, 2, 17), + (15, true, 2, 18), + (16, true, 2, 19), + (17, true, 2, 20), + // Output Layer: Expect 4 + (18, true, 3, 21), + (19, true, 3, 22), + (20, true, 3, 23), + (21, true, 3, 24), + ]; + let new_shape = vec![4, 8, 9, 4]; + let mut new_fann = Fann::new(&[4, 8, 9, 4])?; + // Initialize weights to 0 + let mut new_connections = new_fann.get_connections(); + for connection in new_connections.iter_mut() { + connection.weight = 0.0; + } + new_fann.set_connections(&new_connections); + + // Act + consolidate_old_connections( + &primary_fann, + &secondary_fann, + new_shape, + new_neurons, + &mut new_fann, + ); + + // Bias neurons + // Layer 1: 4 + // Layer 2: 13 + // Layer 3: 23 + let expected_connections = vec![ + // (from_neuron, to_neuron, weight) + // Hidden Layer 1 (5-12) + (0, 5, -5.0), + (1, 5, -105.0), + (2, 5, -205.0), + (3, 5, -305.0), + (0, 6, -6.0), + (1, 6, -106.0), + (2, 6, -206.0), + (3, 6, -306.0), + (0, 7, -7.0), + (1, 7, -107.0), + (2, 7, -207.0), + (3, 7, -307.0), + (0, 8, 8.0), + (1, 8, 108.0), + (2, 8, 208.0), + (3, 8, 308.0), + (0, 9, 9.0), + (1, 9, 109.0), + (2, 9, 209.0), + (3, 9, 309.0), + (0, 10, 10.0), + (1, 10, 110.0), + (2, 10, 210.0), + (3, 10, 310.0), + (0, 11, 11.0), + (1, 11, 111.0), + (2, 11, 211.0), + (3, 11, 311.0), + (0, 12, 12.0), + (1, 12, 112.0), + (2, 12, 212.0), + (3, 12, 312.0), + // Hidden Layer 2 (14-22) + (5, 14, -509.0), + (6, 14, -609.0), + (7, 14, -709.0), + (8, 14, 0.0), + (9, 14, 0.0), + (10, 14, 0.0), + (11, 14, 0.0), + (12, 14, 0.0), + (5, 15, -510.0), + (6, 15, -610.0), + (7, 15, -710.0), + (8, 15, 0.0), + (9, 15, 0.0), + (10, 15, 0.0), + (11, 15, 0.0), + (12, 15, 0.0), + (5, 16, -511.0), + (6, 16, -611.0), + (7, 16, -711.0), + (8, 16, 0.0), + (9, 16, 0.0), + (10, 16, 0.0), + (11, 16, 0.0), + (12, 16, 0.0), + (5, 17, 514.0), + (6, 17, 614.0), + (7, 17, 714.0), + (8, 17, 814.0), + (9, 17, 914.0), + (10, 17, 1014.0), + (11, 17, 1114.0), + (12, 17, 1214.0), + (5, 18, 515.0), + (6, 18, 615.0), + (7, 18, 715.0), + (8, 18, 815.0), + (9, 18, 915.0), + (10, 18, 1015.0), + (11, 18, 1115.0), + (12, 18, 1215.0), + (5, 19, 516.0), + (6, 19, 616.0), + (7, 19, 716.0), + (8, 19, 816.0), + (9, 19, 916.0), + 
(10, 19, 1016.0), + (11, 19, 1116.0), + (12, 19, 1216.0), + (5, 20, 517.0), + (6, 20, 617.0), + (7, 20, 717.0), + (8, 20, 817.0), + (9, 20, 917.0), + (10, 20, 1017.0), + (11, 20, 1117.0), + (12, 20, 1217.0), + (5, 21, 518.0), + (6, 21, 618.0), + (7, 21, 718.0), + (8, 21, 818.0), + (9, 21, 918.0), + (10, 21, 1018.0), + (11, 21, 1118.0), + (12, 21, 1218.0), + (5, 22, 519.0), + (6, 22, 619.0), + (7, 22, 719.0), + (8, 22, 819.0), + (9, 22, 919.0), + (10, 22, 1019.0), + (11, 22, 1119.0), + (12, 22, 1219.0), + // Output layer (24-27) + (14, 24, 0.0), + (15, 24, 0.0), + (16, 24, 0.0), + (17, 24, 1421.0), + (18, 24, 1521.0), + (19, 24, 1621.0), + (20, 24, 1721.0), + (21, 24, 1821.0), + (22, 24, 1921.0), + (14, 25, 0.0), + (15, 25, 0.0), + (16, 25, 0.0), + (17, 25, 1422.0), + (18, 25, 1522.0), + (19, 25, 1622.0), + (20, 25, 1722.0), + (21, 25, 1822.0), + (22, 25, 1922.0), + (14, 26, 0.0), + (15, 26, 0.0), + (16, 26, 0.0), + (17, 26, 1423.0), + (18, 26, 1523.0), + (19, 26, 1623.0), + (20, 26, 1723.0), + (21, 26, 1823.0), + (22, 26, 1923.0), + (14, 27, 0.0), + (15, 27, 0.0), + (16, 27, 0.0), + (17, 27, 1424.0), + (18, 27, 1524.0), + (19, 27, 1624.0), + (20, 27, 1724.0), + (21, 27, 1824.0), + (22, 27, 1924.0), + ]; + + for connection in new_fann.get_connections().iter() { + println!("{:?}", connection); + } + + // Assert + // Compare each connection to the expected connection + let new_connections = new_fann.get_connections(); + for connection in expected_connections.iter() { + let matching_connection = new_connections + .iter() + .find(|&c| c.from_neuron == connection.0 && c.to_neuron == connection.1); + if let Some(matching_connection) = matching_connection { + assert_eq!( + matching_connection.weight, connection.2, + "Connection: {:?}", + matching_connection + ); + } else { + assert!(false, "Connection not found: {:?}", connection); + } + } + + let expected_bias_neuron_connections = vec![ + // (from_neuron, to_neuron, weight) + // Bias Neurons + // Layer 2: bias neuron_id 4 + (4, 5, -405.0), + (4, 6, -406.0), + (4, 7, -407.0), + (4, 8, 408.0), + (4, 9, 409.0), + (4, 10, 410.0), + (4, 11, 411.0), + (4, 12, 412.0), + // Layer 3: bias neuron_id 13 + (13, 14, -809.0), + (13, 15, -810.0), + (13, 16, -811.0), + (13, 17, 1314.0), + (13, 18, 1315.0), + (13, 19, 1316.0), + (13, 20, 1317.0), + (13, 21, 1318.0), + (13, 22, 1319.0), + // Layer 4: bias neuron_id 23 + (23, 24, 2021.0), + (23, 25, 2022.0), + (23, 26, 2023.0), + (23, 27, 2024.0), + ]; + + for connection in expected_bias_neuron_connections.iter() { + let matching_connection = new_connections + .iter() + .find(|&c| c.from_neuron == connection.0 && c.to_neuron == connection.1); + if let Some(matching_connection) = matching_connection { + assert_eq!( + matching_connection.weight, connection.2, + "Connection: {:?}", + matching_connection + ); + } else { + assert!(false, "Connection not found: {:?}", connection); + } + } + + Ok(()) + } +} diff --git a/gemla/src/bin/test_state/mod.rs b/gemla/src/bin/test_state/mod.rs index fac0305..5bde119 100644 --- a/gemla/src/bin/test_state/mod.rs +++ b/gemla/src/bin/test_state/mod.rs @@ -1,6 +1,11 @@ -use gemla::{core::genetic_node::GeneticNode, error::Error}; +use async_trait::async_trait; +use gemla::{ + core::genetic_node::{GeneticNode, GeneticNodeContext}, + error::Error, +}; use rand::prelude::*; use serde::{Deserialize, Serialize}; +use uuid::Uuid; const POPULATION_SIZE: u64 = 5; const POPULATION_REDUCTION_SIZE: u64 = 3; @@ -8,20 +13,30 @@ const POPULATION_REDUCTION_SIZE: u64 = 3; #[derive(Serialize, 
Deserialize, Debug, Clone)] pub struct TestState { pub population: Vec, + pub max_generations: u64, } +#[async_trait] impl GeneticNode for TestState { - fn initialize() -> Result, Error> { + type Context = (); + + async fn initialize(_context: GeneticNodeContext) -> Result, Error> { let mut population: Vec = vec![]; for _ in 0..POPULATION_SIZE { population.push(thread_rng().gen_range(0..100)) } - Ok(Box::new(TestState { population })) + Ok(Box::new(TestState { + population, + max_generations: 10, + })) } - fn simulate(&mut self) -> Result<(), Error> { + async fn simulate( + &mut self, + context: GeneticNodeContext, + ) -> Result { let mut rng = thread_rng(); self.population = self @@ -30,10 +45,14 @@ impl GeneticNode for TestState { .map(|p| p.saturating_add(rng.gen_range(-1..2))) .collect(); - Ok(()) + if context.generation >= self.max_generations { + Ok(false) + } else { + Ok(true) + } } - fn mutate(&mut self) -> Result<(), Error> { + async fn mutate(&mut self, _context: GeneticNodeContext) -> Result<(), Error> { let mut rng = thread_rng(); let mut v = self.population.clone(); @@ -71,7 +90,12 @@ impl GeneticNode for TestState { Ok(()) } - fn merge(left: &TestState, right: &TestState) -> Result, Error> { + async fn merge( + left: &TestState, + right: &TestState, + id: &Uuid, + gemla_context: Self::Context, + ) -> Result, Error> { let mut v = left.population.clone(); v.append(&mut right.population.clone()); @@ -80,9 +104,18 @@ impl GeneticNode for TestState { v = v[..(POPULATION_REDUCTION_SIZE as usize)].to_vec(); - let mut result = TestState { population: v }; + let mut result = TestState { + population: v, + max_generations: 10, + }; - result.mutate()?; + result + .mutate(GeneticNodeContext { + id: *id, + generation: 0, + gemla_context, + }) + .await?; Ok(Box::new(result)) } @@ -93,57 +126,97 @@ mod tests { use super::*; use gemla::core::genetic_node::GeneticNode; - #[test] - fn test_initialize() { - let state = TestState::initialize().unwrap(); + #[tokio::test] + async fn test_initialize() { + let state = TestState::initialize(GeneticNodeContext { + id: Uuid::new_v4(), + generation: 0, + gemla_context: (), + }) + .await + .unwrap(); assert_eq!(state.population.len(), POPULATION_SIZE as usize); } - #[test] - fn test_simulate() { + #[tokio::test] + async fn test_simulate() { let mut state = TestState { population: vec![1, 1, 2, 3], + max_generations: 1, }; let original_population = state.population.clone(); - state.simulate().unwrap(); + state + .simulate(GeneticNodeContext { + id: Uuid::new_v4(), + generation: 0, + gemla_context: (), + }) + .await + .unwrap(); assert!(original_population .iter() .zip(state.population.iter()) .all(|(&a, &b)| b >= a - 1 && b <= a + 2)); - state.simulate().unwrap(); - state.simulate().unwrap(); + state + .simulate(GeneticNodeContext { + id: Uuid::new_v4(), + generation: 0, + gemla_context: (), + }) + .await + .unwrap(); + state + .simulate(GeneticNodeContext { + id: Uuid::new_v4(), + generation: 0, + gemla_context: (), + }) + .await + .unwrap(); assert!(original_population .iter() .zip(state.population.iter()) .all(|(&a, &b)| b >= a - 3 && b <= a + 6)) } - #[test] - fn test_mutate() { + #[tokio::test] + async fn test_mutate() { let mut state = TestState { population: vec![4, 3, 3], + max_generations: 1, }; - state.mutate().unwrap(); + state + .mutate(GeneticNodeContext { + id: Uuid::new_v4(), + generation: 0, + gemla_context: (), + }) + .await + .unwrap(); assert_eq!(state.population.len(), POPULATION_SIZE as usize); } - #[test] - fn test_merge() { + 
#[tokio::test] + async fn test_merge() { let state1 = TestState { population: vec![1, 2, 4, 5], + max_generations: 1, }; let state2 = TestState { population: vec![0, 1, 3, 7], + max_generations: 1, }; - let merged_state = TestState::merge(&state1, &state2).unwrap(); + let merged_state = TestState::merge(&state1, &state2, &Uuid::new_v4(), ()) + .await + .unwrap(); assert_eq!(merged_state.population.len(), POPULATION_SIZE as usize); assert!(merged_state.population.iter().any(|&x| x == 7)); diff --git a/gemla/src/constants/args.rs b/gemla/src/constants/args.rs deleted file mode 100644 index d833fd1..0000000 --- a/gemla/src/constants/args.rs +++ /dev/null @@ -1,2 +0,0 @@ -/// Corresponds to the FILE command line argument used in accordance with the clap crate. -pub const FILE: &str = "FILE"; diff --git a/gemla/src/constants/mod.rs b/gemla/src/constants/mod.rs deleted file mode 100644 index 6e10f4a..0000000 --- a/gemla/src/constants/mod.rs +++ /dev/null @@ -1 +0,0 @@ -pub mod args; diff --git a/gemla/src/core/genetic_node.rs b/gemla/src/core/genetic_node.rs index aeb1b3a..abcb15f 100644 --- a/gemla/src/core/genetic_node.rs +++ b/gemla/src/core/genetic_node.rs @@ -5,7 +5,9 @@ use crate::error::Error; use anyhow::Context; -use serde::{Deserialize, Serialize}; +use async_trait::async_trait; +use log::info; +use serde::{de::DeserializeOwned, Deserialize, Serialize}; use std::fmt::Debug; use uuid::Uuid; @@ -24,45 +26,65 @@ pub enum GeneticState { Finish, } +#[derive(Clone, Debug)] +pub struct GeneticNodeContext { + pub generation: u64, + pub id: Uuid, + pub gemla_context: S, +} + /// A trait used to interact with the internal state of nodes within the [`Bracket`] /// /// [`Bracket`]: crate::bracket::Bracket -pub trait GeneticNode { +#[async_trait] +pub trait GeneticNode: Send { + type Context; + /// Initializes a new instance of a [`GeneticState`]. /// /// # Examples /// TODO - fn initialize() -> Result, Error>; + async fn initialize(context: GeneticNodeContext) -> Result, Error>; - fn simulate(&mut self) -> Result<(), Error>; + async fn simulate(&mut self, context: GeneticNodeContext) + -> Result; /// Mutates members in a population and/or crossbreeds them to produce new offspring. /// /// # Examples /// TODO - fn mutate(&mut self) -> Result<(), Error>; + async fn mutate(&mut self, context: GeneticNodeContext) -> Result<(), Error>; - fn merge(left: &Self, right: &Self) -> Result, Error>; + async fn merge( + left: &Self, + right: &Self, + id: &Uuid, + context: Self::Context, + ) -> Result, Error>; } /// Used externally to wrap a node implementing the [`GeneticNode`] trait. Processes state transitions for the given node as /// well as signal recovery. 
Transition states are given by [`GeneticState`] #[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] -pub struct GeneticNodeWrapper { +pub struct GeneticNodeWrapper +where + T: Clone, +{ node: Option, state: GeneticState, generation: u64, - max_generations: u64, id: Uuid, } -impl Default for GeneticNodeWrapper { +impl Default for GeneticNodeWrapper +where + T: Clone, +{ fn default() -> Self { GeneticNodeWrapper { node: None, state: GeneticState::Initialize, generation: 1, - max_generations: 1, id: Uuid::new_v4(), } } @@ -70,21 +92,20 @@ impl Default for GeneticNodeWrapper { impl GeneticNodeWrapper where - T: GeneticNode + Debug, + T: GeneticNode + Debug + Send + Clone, + T::Context: Send + Sync + Clone + Debug + Serialize + DeserializeOwned + 'static + Default, { - pub fn new(max_generations: u64) -> Self { + pub fn new() -> Self { GeneticNodeWrapper:: { - max_generations, ..Default::default() } } - pub fn from(data: T, max_generations: u64, id: Uuid) -> Self { + pub fn from(data: T, id: Uuid) -> Self { GeneticNodeWrapper { node: Some(data), state: GeneticState::Simulate, generation: 1, - max_generations, id, } } @@ -93,36 +114,51 @@ where self.node.as_ref() } + pub fn take(&mut self) -> Option { + self.node.take() + } + pub fn id(&self) -> Uuid { self.id } - pub fn max_generations(&self) -> u64 { - self.max_generations + pub fn generation(&self) -> u64 { + self.generation } pub fn state(&self) -> GeneticState { self.state } - pub fn process_node(&mut self) -> Result { + pub async fn process_node(&mut self, gemla_context: T::Context) -> Result { + let context = GeneticNodeContext { + generation: self.generation, + id: self.id, + gemla_context, + }; + match (self.state, &mut self.node) { (GeneticState::Initialize, _) => { - self.node = Some(*T::initialize()?); + self.node = Some(*T::initialize(context.clone()).await?); self.state = GeneticState::Simulate; } (GeneticState::Simulate, Some(n)) => { - n.simulate() + let next_generation = n + .simulate(context.clone()) + .await .with_context(|| format!("Error simulating node: {:?}", self))?; - self.state = if self.generation >= self.max_generations { - GeneticState::Finish - } else { + info!("Simulation complete and continuing: {:?}", next_generation); + + self.state = if next_generation { GeneticState::Mutate + } else { + GeneticState::Finish }; } (GeneticState::Mutate, Some(n)) => { - n.mutate() + n.mutate(context.clone()) + .await .with_context(|| format!("Error mutating node: {:?}", self))?; self.generation += 1; @@ -141,40 +177,64 @@ mod tests { use super::*; use crate::error::Error; use anyhow::anyhow; + use async_trait::async_trait; #[derive(Deserialize, Serialize, Clone, Debug, PartialEq)] struct TestState { pub score: f64, + pub max_generations: u64, } + #[async_trait] impl GeneticNode for TestState { - fn simulate(&mut self) -> Result<(), Error> { + type Context = (); + + async fn simulate( + &mut self, + context: GeneticNodeContext, + ) -> Result { self.score += 1.0; + if context.generation >= self.max_generations { + Ok(false) + } else { + Ok(true) + } + } + + async fn mutate( + &mut self, + _context: GeneticNodeContext, + ) -> Result<(), Error> { Ok(()) } - fn mutate(&mut self) -> Result<(), Error> { - Ok(()) + async fn initialize( + _context: GeneticNodeContext, + ) -> Result, Error> { + Ok(Box::new(TestState { + score: 0.0, + max_generations: 2, + })) } - fn initialize() -> Result, Error> { - Ok(Box::new(TestState { score: 0.0 })) - } - - fn merge(_l: &TestState, _r: &TestState) -> Result, Error> { + async fn merge( + _l: 
&TestState, + _r: &TestState, + _id: &Uuid, + _: Self::Context, + ) -> Result, Error> { Err(Error::Other(anyhow!("Unable to merge"))) } } #[test] fn test_new() -> Result<(), Error> { - let genetic_node = GeneticNodeWrapper::::new(10); + let genetic_node = GeneticNodeWrapper::::new(); let other_genetic_node = GeneticNodeWrapper:: { node: None, state: GeneticState::Initialize, generation: 1, - max_generations: 10, id: genetic_node.id(), }; @@ -185,15 +245,17 @@ mod tests { #[test] fn test_from() -> Result<(), Error> { - let val = TestState { score: 0.0 }; + let val = TestState { + score: 0.0, + max_generations: 10, + }; let uuid = Uuid::new_v4(); - let genetic_node = GeneticNodeWrapper::from(val.clone(), 10, uuid); + let genetic_node = GeneticNodeWrapper::from(val.clone(), uuid); let other_genetic_node = GeneticNodeWrapper:: { node: Some(val), state: GeneticState::Simulate, generation: 1, - max_generations: 10, id: genetic_node.id(), }; @@ -204,9 +266,12 @@ mod tests { #[test] fn test_as_ref() -> Result<(), Error> { - let val = TestState { score: 3.0 }; + let val = TestState { + score: 3.0, + max_generations: 10, + }; let uuid = Uuid::new_v4(); - let genetic_node = GeneticNodeWrapper::from(val.clone(), 10, uuid); + let genetic_node = GeneticNodeWrapper::from(val.clone(), uuid); let ref_value = genetic_node.as_ref().unwrap(); @@ -217,9 +282,12 @@ mod tests { #[test] fn test_id() -> Result<(), Error> { - let val = TestState { score: 3.0 }; + let val = TestState { + score: 3.0, + max_generations: 10, + }; let uuid = Uuid::new_v4(); - let genetic_node = GeneticNodeWrapper::from(val.clone(), 10, uuid); + let genetic_node = GeneticNodeWrapper::from(val.clone(), uuid); let id_value = genetic_node.id(); @@ -228,24 +296,14 @@ mod tests { Ok(()) } - #[test] - fn test_max_generations() -> Result<(), Error> { - let val = TestState { score: 3.0 }; - let uuid = Uuid::new_v4(); - let genetic_node = GeneticNodeWrapper::from(val.clone(), 10, uuid); - - let max_generations = genetic_node.max_generations(); - - assert_eq!(max_generations, 10); - - Ok(()) - } - #[test] fn test_state() -> Result<(), Error> { - let val = TestState { score: 3.0 }; + let val = TestState { + score: 3.0, + max_generations: 10, + }; let uuid = Uuid::new_v4(); - let genetic_node = GeneticNodeWrapper::from(val.clone(), 10, uuid); + let genetic_node = GeneticNodeWrapper::from(val.clone(), uuid); let state = genetic_node.state(); @@ -254,16 +312,16 @@ mod tests { Ok(()) } - #[test] - fn test_process_node() -> Result<(), Error> { - let mut genetic_node = GeneticNodeWrapper::::new(2); + #[tokio::test] + async fn test_process_node() -> Result<(), Error> { + let mut genetic_node = GeneticNodeWrapper::::new(); assert_eq!(genetic_node.state(), GeneticState::Initialize); - assert_eq!(genetic_node.process_node()?, GeneticState::Simulate); - assert_eq!(genetic_node.process_node()?, GeneticState::Mutate); - assert_eq!(genetic_node.process_node()?, GeneticState::Simulate); - assert_eq!(genetic_node.process_node()?, GeneticState::Finish); - assert_eq!(genetic_node.process_node()?, GeneticState::Finish); + assert_eq!(genetic_node.process_node(()).await?, GeneticState::Simulate); + assert_eq!(genetic_node.process_node(()).await?, GeneticState::Mutate); + assert_eq!(genetic_node.process_node(()).await?, GeneticState::Simulate); + assert_eq!(genetic_node.process_node(()).await?, GeneticState::Finish); + assert_eq!(genetic_node.process_node(()).await?, GeneticState::Finish); Ok(()) } diff --git a/gemla/src/core/mod.rs b/gemla/src/core/mod.rs index 
7564423..d3de3f3 100644 --- a/gemla/src/core/mod.rs +++ b/gemla/src/core/mod.rs @@ -4,42 +4,44 @@ pub mod genetic_node; use crate::{error::Error, tree::Tree}; -use file_linked::FileLinked; -use futures::{future, future::BoxFuture}; +use async_recursion::async_recursion; +use file_linked::{constants::data_format::DataFormat, FileLinked}; +use futures::future; use genetic_node::{GeneticNode, GeneticNodeWrapper, GeneticState}; use log::{info, trace, warn}; use serde::{de::DeserializeOwned, Deserialize, Serialize}; use std::{ collections::HashMap, fmt::Debug, fs::File, io::ErrorKind, marker::Send, mem, path::Path, - time::Instant, + sync::Arc, time::Instant, }; +use tokio::{sync::RwLock, task::JoinHandle}; use uuid::Uuid; type SimulationTree = Box>>; /// Provides configuration options for managing a [`Gemla`] object as it executes. -/// +/// /// # Examples -/// ``` +/// ```rust,ignore /// #[derive(Deserialize, Serialize, Clone, Debug, PartialEq)] /// struct TestState { /// pub score: f64, /// } -/// +/// /// impl genetic_node::GeneticNode for TestState { /// fn simulate(&mut self) -> Result<(), Error> { /// self.score += 1.0; /// Ok(()) /// } -/// +/// /// fn mutate(&mut self) -> Result<(), Error> { /// Ok(()) /// } -/// +/// /// fn initialize() -> Result, Error> { /// Ok(Box::new(TestState { score: 0.0 })) /// } -/// +/// /// fn merge(left: &TestState, right: &TestState) -> Result, Error> { /// Ok(Box::new(if left.score > right.score { /// left.clone() @@ -48,14 +50,13 @@ type SimulationTree = Box>>; /// })) /// } /// } -/// +/// /// fn main() { /// /// } /// ``` #[derive(Serialize, Deserialize, Copy, Clone)] pub struct GemlaConfig { - pub generations_per_node: u64, pub overwrite: bool, } @@ -65,79 +66,125 @@ pub struct GemlaConfig { /// individuals. /// /// [`GeneticNode`]: genetic_node::GeneticNode -pub struct Gemla<'a, T> +pub struct Gemla where - T: Serialize + Clone, + T: GeneticNode + Serialize + DeserializeOwned + Debug + Send + Clone, + T::Context: Send + Sync + Clone + Debug + Serialize + DeserializeOwned + 'static + Default, { - pub data: FileLinked<(Option>, GemlaConfig)>, - threads: HashMap, Error>>>, + pub data: FileLinked<(Option>, GemlaConfig, T::Context)>, + threads: HashMap, Error>>>, } -impl<'a, T: 'a> Gemla<'a, T> +impl Gemla where - T: GeneticNode + Serialize + DeserializeOwned + Debug + Clone + Send, + T: GeneticNode + Serialize + DeserializeOwned + Debug + Send + Sync + Clone, + T::Context: Send + Sync + Clone + Debug + Serialize + DeserializeOwned + 'static + Default, { - pub fn new(path: &Path, config: GemlaConfig) -> Result { + pub async fn new( + path: &Path, + config: GemlaConfig, + data_format: DataFormat, + ) -> Result { match File::open(path) { - // If the file exists we either want to overwrite the file or read from the file + // If the file exists we either want to overwrite the file or read from the file // based on the configuration provided Ok(_) => Ok(Gemla { data: if config.overwrite { - FileLinked::new((None, config), path)? + FileLinked::new((None, config, T::Context::default()), path, data_format) + .await? } else { - FileLinked::from_file(path)? + FileLinked::from_file(path, data_format)? 
}, threads: HashMap::new(), }), // If the file doesn't exist we must create it Err(error) if error.kind() == ErrorKind::NotFound => Ok(Gemla { - data: FileLinked::new((None, config), path)?, + data: FileLinked::new((None, config, T::Context::default()), path, data_format) + .await?, threads: HashMap::new(), }), Err(error) => Err(Error::IO(error)), } } - pub fn tree_ref(&self) -> Option<&SimulationTree> { - self.data.readonly().0.as_ref() + pub fn tree_ref(&self) -> Arc>, GemlaConfig, T::Context)>> { + self.data.readonly().clone() } pub async fn simulate(&mut self, steps: u64) -> Result<(), Error> { - // Before we can process nodes we must create blank nodes in their place to keep track of which nodes have been processed - // in the tree and which nodes have not. - self.data.mutate(|(d, c)| { - let mut tree: Option> = Gemla::increase_height(d.take(), c, steps); - mem::swap(d, &mut tree); - })?; + let tree_completed = { + // Only increase height if the tree is uninitialized or completed + let data_arc = self.data.readonly(); + let data_ref = data_arc.read().await; + let tree_ref = data_ref.0.as_ref(); - info!( - "Height of simulation tree increased to {}", - self.tree_ref() - .map(|t| format!("{}", t.height())) - .unwrap_or_else(|| "Tree is not defined".to_string()) - ); + tree_ref.is_none() || tree_ref.map(|t| Gemla::is_completed(t)).unwrap_or(true) + }; + + if tree_completed { + // Before we can process nodes we must create blank nodes in their place to keep track of which nodes have been processed + // in the tree and which nodes have not. + self.data + .mutate(|(d, _, _)| { + let mut tree: Option> = + Gemla::increase_height(d.take(), steps); + mem::swap(d, &mut tree); + }) + .await?; + } + + { + // Only increase height if the tree is uninitialized or completed + let data_arc = self.data.readonly(); + let data_ref = data_arc.read().await; + let tree_ref = data_ref.0.as_ref(); + + info!( + "Height of simulation tree increased to {}", + tree_ref + .map(|t| format!("{}", t.height())) + .unwrap_or_else(|| "Tree is not defined".to_string()) + ); + } loop { - // We need to keep simulating until the tree has been completely processed. - if self - .tree_ref() - .map(|t| Gemla::is_completed(t)) - .unwrap_or(false) + let is_tree_processed; + { + let data_arc = self.data.readonly(); + let data_ref = data_arc.read().await; + let tree_ref = data_ref.0.as_ref(); + + is_tree_processed = tree_ref.map(|t| Gemla::is_completed(t)).unwrap_or(false) + } + + // We need to keep simulating until the tree has been completely processed. 
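+            // Once every node in the tree reports GeneticState::Finish, join any in-flight
+            // worker tasks and exit the simulation loop.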
+ if is_tree_processed { self.join_threads().await?; info!("Processed tree"); break; } - if let Some(node) = self - .tree_ref() - .and_then(|t| self.get_unprocessed_node(t)) - { + let (node, gemla_context) = { + let data_arc = self.data.readonly(); + let data_ref = data_arc.read().await; + let (tree_ref, _, gemla_context) = &*data_ref; // (Option>>, GemlaConfig, T::Context) + + let node = tree_ref.as_ref().and_then(|t| self.get_unprocessed_node(t)); + + (node, gemla_context.clone()) + }; + + if let Some(node) = node { trace!("Adding node to process list {}", node.id()); - self.threads - .insert(node.id(), Box::pin(Gemla::process_node(node))); + let gemla_context = gemla_context.clone(); + + self.threads.insert( + node.id(), + tokio::spawn(async move { Gemla::process_node(node, gemla_context).await }), + ); } else { trace!("No node found to process, joining threads"); @@ -153,38 +200,56 @@ where trace!("Joining threads for nodes {:?}", self.threads.keys()); let results = future::join_all(self.threads.values_mut()).await; + // Converting a list of results into a result wrapping the list let reduced_results: Result>, Error> = - results.into_iter().collect(); + results.into_iter().flatten().collect(); self.threads.clear(); // We need to retrieve the processed nodes from the resulting list and replace them in the original list - reduced_results.and_then(|r| { - self.data.mutate(|(d, _)| { - if let Some(t) = d { - let failed_nodes = Gemla::replace_nodes(t, r); - // We receive a list of nodes that were unable to be found in the original tree - if !failed_nodes.is_empty() { - warn!( - "Unable to find {:?} to replace in tree", - failed_nodes.iter().map(|n| n.id()) - ) - } + match reduced_results { + Ok(r) => { + self.data + .mutate_async(|d| async move { + // Scope to limit the duration of the read lock + let (_, context) = { + let data_read = d.read().await; + (data_read.1, data_read.2.clone()) + }; // Read lock is dropped here - // Once the nodes are replaced we need to find nodes that can be merged from the completed children nodes - Gemla::merge_completed_nodes(t) - } else { - warn!("Unable to replce nodes {:?} in empty tree", r); - Ok(()) - } - })? 
- })?; + let mut data_write = d.write().await; + + if let Some(t) = data_write.0.as_mut() { + let failed_nodes = Gemla::replace_nodes(t, r); + // We receive a list of nodes that were unable to be found in the original tree + if !failed_nodes.is_empty() { + warn!( + "Unable to find {:?} to replace in tree", + failed_nodes.iter().map(|n| n.id()) + ) + } + + // Once the nodes are replaced we need to find nodes that can be merged from the completed children nodes + Gemla::merge_completed_nodes(t, context.clone()).await + } else { + warn!("Unable to replce nodes {:?} in empty tree", r); + Ok(()) + } + }) + .await??; + } + Err(e) => return Err(e), + } } Ok(()) } - fn merge_completed_nodes(tree: &mut SimulationTree) -> Result<(), Error> { + #[async_recursion] + async fn merge_completed_nodes<'a>( + tree: &'a mut SimulationTree, + gemla_context: T::Context, + ) -> Result<(), Error> { if tree.val.state() == GeneticState::Initialize { match (&mut tree.left, &mut tree.right) { // If the current node has been initialized, and has children nodes that are completed, then we need @@ -195,43 +260,37 @@ where { info!("Merging nodes {} and {}", l.val.id(), r.val.id()); if let (Some(left_node), Some(right_node)) = (l.val.as_ref(), r.val.as_ref()) { - let merged_node = GeneticNode::merge(left_node, right_node)?; - tree.val = GeneticNodeWrapper::from( - *merged_node, - tree.val.max_generations(), - tree.val.id(), - ); + let merged_node = GeneticNode::merge( + left_node, + right_node, + &tree.val.id(), + gemla_context.clone(), + ) + .await?; + tree.val = GeneticNodeWrapper::from(*merged_node, tree.val.id()); } } (Some(l), Some(r)) => { - Gemla::merge_completed_nodes(l)?; - Gemla::merge_completed_nodes(r)?; + Gemla::merge_completed_nodes(l, gemla_context.clone()).await?; + Gemla::merge_completed_nodes(r, gemla_context.clone()).await?; } // If there is only one child node that's completed then we want to copy it to the parent node (Some(l), None) if l.val.state() == GeneticState::Finish => { trace!("Copying node {}", l.val.id()); if let Some(left_node) = l.val.as_ref() { - GeneticNodeWrapper::from( - left_node.clone(), - tree.val.max_generations(), - tree.val.id(), - ); + GeneticNodeWrapper::from(left_node.clone(), tree.val.id()); } } - (Some(l), None) => Gemla::merge_completed_nodes(l)?, + (Some(l), None) => Gemla::merge_completed_nodes(l, gemla_context.clone()).await?, (None, Some(r)) if r.val.state() == GeneticState::Finish => { trace!("Copying node {}", r.val.id()); if let Some(right_node) = r.val.as_ref() { - tree.val = GeneticNodeWrapper::from( - right_node.clone(), - tree.val.max_generations(), - tree.val.id(), - ); + tree.val = GeneticNodeWrapper::from(right_node.clone(), tree.val.id()); } } - (None, Some(r)) => Gemla::merge_completed_nodes(r)?, + (None, Some(r)) => Gemla::merge_completed_nodes(r, gemla_context.clone()).await?, (_, _) => (), } } @@ -240,15 +299,18 @@ where } fn get_unprocessed_node(&self, tree: &SimulationTree) -> Option> { - // If the current node has been processed or exists in the thread list then we want to stop recursing. Checking if it exists in the thread list + // If the current node has been processed or exists in the thread list then we want to stop recursing. Checking if it exists in the thread list // should be fine because we process the tree from bottom to top. if tree.val.state() != GeneticState::Finish && !self.threads.contains_key(&tree.val.id()) { match (&tree.left, &tree.right) { - // If the children are finished we can start processing the currrent node. 
The current node should be merged from the children already + // If the children are finished we can start processing the currrent node. The current node should be merged from the children already // during join_threads. (Some(l), Some(r)) if l.val.state() == GeneticState::Finish - && r.val.state() == GeneticState::Finish => Some(tree.val.clone()), + && r.val.state() == GeneticState::Finish => + { + Some(tree.val.clone()) + } (Some(l), Some(r)) => self .get_unprocessed_node(l) .or_else(|| self.get_unprocessed_node(r)), @@ -278,25 +340,19 @@ where } } - fn increase_height( - tree: Option>, - config: &GemlaConfig, - amount: u64, - ) -> Option> { + fn increase_height(tree: Option>, amount: u64) -> Option> { if amount == 0 { tree } else { - let left_branch_right = + let left_branch_height = tree.as_ref().map(|t| t.height() as u64).unwrap_or(0) + amount - 1; - + Some(Box::new(Tree::new( - GeneticNodeWrapper::new(config.generations_per_node), - Gemla::increase_height(tree, config, amount - 1), + GeneticNodeWrapper::new(), + Gemla::increase_height(tree, amount - 1), // The right branch height has to equal the left branches total height - if left_branch_right > 0 { - Some(Box::new(btree!(GeneticNodeWrapper::new( - left_branch_right * config.generations_per_node - )))) + if left_branch_height > 0 { + Some(Box::new(btree!(GeneticNodeWrapper::new()))) } else { None }, @@ -306,16 +362,19 @@ where fn is_completed(tree: &SimulationTree) -> bool { // If the current node is finished, then by convention the children should all be finished as well - tree.val.state() == GeneticState::Finish + tree.val.state() == GeneticState::Finish } - async fn process_node(mut node: GeneticNodeWrapper) -> Result, Error> { + async fn process_node( + mut node: GeneticNodeWrapper, + gemla_context: T::Context, + ) -> Result, Error> { let node_state_time = Instant::now(); let node_state = node.state(); - node.process_node()?; + node.process_node(gemla_context.clone()).await?; - trace!( + info!( "{:?} completed in {:?} for {}", node_state, node_state_time.elapsed(), @@ -333,9 +392,13 @@ where #[cfg(test)] mod tests { use crate::core::*; + use async_trait::async_trait; use serde::{Deserialize, Serialize}; - use std::path::PathBuf; use std::fs; + use std::path::PathBuf; + use tokio::runtime::Runtime; + + use self::genetic_node::GeneticNodeContext; struct CleanUp { path: PathBuf, @@ -364,23 +427,43 @@ mod tests { #[derive(Deserialize, Serialize, Clone, Debug, PartialEq)] struct TestState { pub score: f64, + pub max_generations: u64, } + #[async_trait] impl genetic_node::GeneticNode for TestState { - fn simulate(&mut self) -> Result<(), Error> { + type Context = (); + + async fn simulate( + &mut self, + context: GeneticNodeContext, + ) -> Result { self.score += 1.0; + Ok(context.generation < self.max_generations) + } + + async fn mutate( + &mut self, + _context: GeneticNodeContext, + ) -> Result<(), Error> { Ok(()) } - fn mutate(&mut self) -> Result<(), Error> { - Ok(()) + async fn initialize( + _context: GeneticNodeContext, + ) -> Result, Error> { + Ok(Box::new(TestState { + score: 0.0, + max_generations: 10, + })) } - fn initialize() -> Result, Error> { - Ok(Box::new(TestState { score: 0.0 })) - } - - fn merge(left: &TestState, right: &TestState) -> Result, Error> { + async fn merge( + left: &TestState, + right: &TestState, + _id: &Uuid, + _: Self::Context, + ) -> Result, Error> { Ok(Box::new(if left.score > right.score { left.clone() } else { @@ -389,66 +472,93 @@ mod tests { } } - #[test] - fn test_new() -> Result<(), Error> { + 
#[tokio::test] + async fn test_new() -> Result<(), Error> { let path = PathBuf::from("test_new_non_existing"); - CleanUp::new(&path).run(|p| { - assert!(!path.exists()); + // Use `spawn_blocking` to run synchronous code that needs to call async code internally. + tokio::task::spawn_blocking(move || { + let rt = Runtime::new().unwrap(); // Create a new Tokio runtime for the async block. + CleanUp::new(&path).run(move |p| { + rt.block_on(async { + assert!(!path.exists()); - // Testing initial creation - let mut config = GemlaConfig { - generations_per_node: 1, - overwrite: true - }; - let mut gemla = Gemla::::new(&p, config)?; + // Testing initial creation + let mut config = GemlaConfig { overwrite: true }; + let mut gemla = Gemla::::new(&p, config, DataFormat::Json).await?; - smol::block_on(gemla.simulate(2))?; - assert_eq!(gemla.data.readonly().0.as_ref().unwrap().height(), 2); - - drop(gemla); - assert!(path.exists()); + // Now we can use `.await` within the spawned blocking task. + gemla.simulate(2).await?; + let data = gemla.data.readonly(); + let data_lock = data.read().await; + assert_eq!(data_lock.0.as_ref().unwrap().height(), 2); - // Testing overwriting data - let mut gemla = Gemla::::new(&p, config)?; + drop(data_lock); + drop(gemla); + assert!(path.exists()); - smol::block_on(gemla.simulate(2))?; - assert_eq!(gemla.data.readonly().0.as_ref().unwrap().height(), 2); + // Testing overwriting data + let mut gemla = Gemla::::new(&p, config, DataFormat::Json).await?; - drop(gemla); - assert!(path.exists()); + gemla.simulate(2).await?; + let data = gemla.data.readonly(); + let data_lock = data.read().await; + assert_eq!(data_lock.0.as_ref().unwrap().height(), 2); - // Testing not-overwriting data - config.overwrite = false; - let mut gemla = Gemla::::new(&p, config)?; + drop(data_lock); + drop(gemla); + assert!(path.exists()); - smol::block_on(gemla.simulate(2))?; - assert_eq!(gemla.tree_ref().unwrap().height(), 4); + // Testing not-overwriting data + config.overwrite = false; + let mut gemla = Gemla::::new(&p, config, DataFormat::Json).await?; - drop(gemla); - assert!(path.exists()); + gemla.simulate(2).await?; + let data = gemla.data.readonly(); + let data_lock = data.read().await; + let tree = data_lock.0.as_ref().unwrap(); + assert_eq!(tree.height(), 4); - Ok(()) + drop(data_lock); + drop(gemla); + assert!(path.exists()); + + Ok(()) + }) + }) }) + .await + .unwrap()?; // Wait for the blocking task to complete, then handle the Result. + + Ok(()) } - #[test] - fn test_simulate() -> Result<(), Error> { + #[tokio::test] + async fn test_simulate() -> Result<(), Error> { let path = PathBuf::from("test_simulate"); - CleanUp::new(&path).run(|p| { - // Testing initial creation - let config = GemlaConfig { - generations_per_node: 10, - overwrite: true - }; - let mut gemla = Gemla::::new(&p, config)?; + // Use `spawn_blocking` to run the synchronous closure that internally awaits async code. + tokio::task::spawn_blocking(move || { + let rt = Runtime::new().unwrap(); // Create a new Tokio runtime for the async block. + CleanUp::new(&path).run(move |p| { + rt.block_on(async { + // Testing initial creation + let config = GemlaConfig { overwrite: true }; + let mut gemla = Gemla::::new(&p, config, DataFormat::Json).await?; - smol::block_on(gemla.simulate(5))?; - let tree = gemla.tree_ref().unwrap(); - assert_eq!(tree.height(), 5); - assert_eq!(tree.val.as_ref().unwrap().score, 50.0); + // Now we can use `.await` within the spawned blocking task. 
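+                    // The assertion below expects a root score of 50.0: each node simulates
+                    // max_generations = 10 (adding 1.0 per generation), and merging carries the
+                    // higher-scoring child up each of the 5 levels of the tree, so the winning
+                    // lineage accumulates 10.0 per level.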
+ gemla.simulate(5).await?; + let data = gemla.data.readonly(); + let data_lock = data.read().await; + let tree = data_lock.0.as_ref().unwrap(); + assert_eq!(tree.height(), 5); + assert_eq!(tree.val.as_ref().unwrap().score, 50.0); - Ok(()) + Ok(()) + }) + }) }) - } + .await + .unwrap()?; // Wait for the blocking task to complete, then handle the Result. + Ok(()) + } } diff --git a/gemla/src/lib.rs b/gemla/src/lib.rs index 77d1371..d69f277 100644 --- a/gemla/src/lib.rs +++ b/gemla/src/lib.rs @@ -1,5 +1,4 @@ #[macro_use] pub mod tree; -pub mod constants; pub mod core; pub mod error; diff --git a/gemla/src/tree/mod.rs b/gemla/src/tree/mod.rs index c1a2b39..1388aaf 100644 --- a/gemla/src/tree/mod.rs +++ b/gemla/src/tree/mod.rs @@ -36,7 +36,7 @@ use std::cmp::max; /// t.right = Some(Box::new(btree!(3))); /// assert_eq!(t.right.unwrap().val, 3); /// ``` -#[derive(Default, Serialize, Deserialize, Clone, PartialEq, Debug)] +#[derive(Default, Serialize, Deserialize, PartialEq, Debug)] pub struct Tree { pub val: T, pub left: Option>>, diff --git a/parameter_analysis.py b/parameter_analysis.py new file mode 100644 index 0000000..f3993fb --- /dev/null +++ b/parameter_analysis.py @@ -0,0 +1,380 @@ +# Re-importing necessary libraries +import json +import matplotlib.pyplot as plt +from collections import defaultdict +import numpy as np +import pandas as pd +import seaborn as sns +import matplotlib.colors as mcolors +import matplotlib.cm as cm +import matplotlib.ticker as ticker + +# Simplified JSON data for demonstration +with open('gemla/round4.json', 'r') as file: + simplified_json_data = json.load(file) + +# Function to traverse the tree to find a node id +def traverse_right_nodes(node): + if node is None: + return [] + + right_node = node.get("right") + left_node = node.get("left") + + if right_node is None and left_node is None: + return [] + elif right_node and left_node: + return [right_node] + traverse_right_nodes(left_node) + + return [] + +# Getting most recent right graph +right_nodes = traverse_right_nodes(simplified_json_data[0]) + +# Heatmaps +# Data structure to store mutation rates, generations, and scores +mutation_rate_data = defaultdict(lambda: defaultdict(list)) + +# Populate the dictionary with scores indexed by mutation rate and generation +for node in right_nodes: + node_val = node["val"]["node"] + if node_val: + scores = node_val["scores"] + minor_mutation_rate = node_val["minor_mutation_rate"] + generation = node_val["generation"] + # Ensure each score is associated with the correct generation + for gen_index, score_list in enumerate(scores): + for score in score_list.values(): + mutation_rate_data[minor_mutation_rate][gen_index].append(score) + +# Prepare data for heatmap +max_generation = max(max(gens.keys()) for gens in mutation_rate_data.values()) +heatmap_data = np.full((len(mutation_rate_data), max_generation + 1), np.nan) + +# Populate the heatmap data with average scores +mutation_rates = sorted(mutation_rate_data.keys()) +for i, mutation_rate in enumerate(mutation_rates): + for generation in range(max_generation + 1): + scores = mutation_rate_data[mutation_rate][generation] + if scores: # Check if there are scores for this generation + heatmap_data[i, generation] = np.mean(scores) + +# Creating a DataFrame for the heatmap +df_heatmap = pd.DataFrame( + data=heatmap_data, + index=mutation_rates, + columns=range(max_generation + 1) +) + +# Data structure to store major mutation rates, generations, and scores +major_mutation_rate_data = defaultdict(lambda: defaultdict(list)) 
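+# Nested mapping: major mutation rate -> generation index -> list of raw scores for that generation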
+ +# Populate the dictionary with scores indexed by major mutation rate and generation +# This is assuming the structure to retrieve major_mutation_rate is similar to minor_mutation_rate +for node in right_nodes: + node_val = node["val"]["node"] + if node_val: + scores = node_val["scores"] + major_mutation_rate = node_val["major_mutation_rate"] + generation = node_val["generation"] + for gen_index, score_list in enumerate(scores): + for score in score_list.values(): + major_mutation_rate_data[major_mutation_rate][gen_index].append(score) + +# Prepare the heatmap data for major_mutation_rate similar to minor_mutation_rate +major_heatmap_data = np.full((len(major_mutation_rate_data), max_generation + 1), np.nan) +major_mutation_rates = sorted(major_mutation_rate_data.keys()) + +for i, major_rate in enumerate(major_mutation_rates): + for generation in range(max_generation + 1): + scores = major_mutation_rate_data[major_rate][generation] + if scores: # Check if there are scores for this generation + major_heatmap_data[i, generation] = np.mean(scores) + +# Creating a DataFrame for the major mutation rate heatmap +df_major_heatmap = pd.DataFrame( + data=major_heatmap_data, + index=major_mutation_rates, + columns=range(max_generation + 1) +) + +# crossbreed_segments +# Data structure to store major mutation rates, generations, and scores +crossbreed_segments_data = defaultdict(lambda: defaultdict(list)) + +# Populate the dictionary with scores indexed by major mutation rate and generation +# This is assuming the structure to retrieve major_mutation_rate is similar to minor_mutation_rate +for node in right_nodes: + node_val = node["val"]["node"] + if node_val: + scores = node_val["scores"] + crossbreed_segments = node_val["crossbreed_segments"] + generation = node_val["generation"] + for gen_index, score_list in enumerate(scores): + for score in score_list.values(): + crossbreed_segments_data[crossbreed_segments][gen_index].append(score) + +# Prepare the heatmap data for crossbreed_segments similar to minor_mutation_rate +crossbreed_heatmap_data = np.full((len(crossbreed_segments_data), max_generation + 1), np.nan) +crossbreed_segments = sorted(crossbreed_segments_data.keys()) + +for i, crossbreed_segment in enumerate(crossbreed_segments): + for generation in range(max_generation + 1): + scores = crossbreed_segments_data[crossbreed_segment][generation] + if scores: # Check if there are scores for this generation + crossbreed_heatmap_data[i, generation] = np.mean(scores) + +# Creating a DataFrame for the major mutation rate heatmap +df_crossbreed_heatmap = pd.DataFrame( + data=crossbreed_heatmap_data, + index=crossbreed_segments, + columns=range(max_generation + 1) +) + +# mutation_weight_range +# Data structure to store major mutation rates, generations, and scores +mutation_weight_range_data = defaultdict(lambda: defaultdict(list)) + +# Populate the dictionary with scores indexed by major mutation rate and generation +# This is assuming the structure to retrieve major_mutation_rate is similar to minor_mutation_rate +for node in right_nodes: + node_val = node["val"]["node"] + if node_val: + scores = node_val["scores"] + mutation_weight_range = node_val["mutation_weight_range"] + positive_extent = mutation_weight_range["end"] + negative_extent = -mutation_weight_range["start"] + mutation_weight_range = (positive_extent + negative_extent) / 2 + generation = node_val["generation"] + for gen_index, score_list in enumerate(scores): + for score in score_list.values(): + 
mutation_weight_range_data[mutation_weight_range][gen_index].append(score)
+
+# Prepare the heatmap data for mutation_weight_range similar to minor_mutation_rate
+mutation_weight_range_heatmap_data = np.full((len(mutation_weight_range_data), max_generation + 1), np.nan)
+mutation_weight_ranges = sorted(mutation_weight_range_data.keys())
+
+for i, mutation_weight_range in enumerate(mutation_weight_ranges):
+    for generation in range(max_generation + 1):
+        scores = mutation_weight_range_data[mutation_weight_range][generation]
+        if scores:  # Check if there are scores for this generation
+            mutation_weight_range_heatmap_data[i, generation] = np.mean(scores)
+
+# Creating a DataFrame for the mutation weight range heatmap
+df_mutation_weight_range_heatmap = pd.DataFrame(
+    data=mutation_weight_range_heatmap_data,
+    index=mutation_weight_ranges,
+    columns=range(max_generation + 1)
+)
+
+# weight_initialization_range
+# Data structure to store weight initialization ranges, generations, and scores
+weight_initialization_range_data = defaultdict(lambda: defaultdict(list))
+
+# Populate the dictionary with scores indexed by weight initialization range and generation
+# This is assuming the structure to retrieve weight_initialization_range is similar to minor_mutation_rate
+for node in right_nodes:
+    node_val = node["val"]["node"]
+    if node_val:
+        scores = node_val["scores"]
+        weight_initialization_range = node_val["weight_initialization_range"]
+        positive_extent = weight_initialization_range["end"]
+        negative_extent = -weight_initialization_range["start"]
+        weight_initialization_range = (positive_extent + negative_extent) / 2
+        generation = node_val["generation"]
+        for gen_index, score_list in enumerate(scores):
+            for score in score_list.values():
+                weight_initialization_range_data[weight_initialization_range][gen_index].append(score)
+
+# Prepare the heatmap data for weight_initialization_range similar to minor_mutation_rate
+weight_initialization_range_heatmap_data = np.full((len(weight_initialization_range_data), max_generation + 1), np.nan)
+weight_initialization_ranges = sorted(weight_initialization_range_data.keys())
+
+for i, weight_initialization_range in enumerate(weight_initialization_ranges):
+    for generation in range(max_generation + 1):
+        scores = weight_initialization_range_data[weight_initialization_range][generation]
+        if scores:  # Check if there are scores for this generation
+            weight_initialization_range_heatmap_data[i, generation] = np.mean(scores)
+
+# Creating a DataFrame for the weight initialization range heatmap
+df_weight_initialization_range_heatmap = pd.DataFrame(
+    data=weight_initialization_range_heatmap_data,
+    index=weight_initialization_ranges,
+    columns=range(max_generation + 1)
+)
+
+# weight_initialization_range_skew
+# Data structure to store weight initialization range skews, generations, and scores
+weight_initialization_range_skew_data = defaultdict(lambda: defaultdict(list))
+
+# Populate the dictionary with scores indexed by weight initialization range skew and generation
+# This is assuming the structure to retrieve weight_initialization_range is similar to minor_mutation_rate
+for node in right_nodes:
+    node_val = node["val"]["node"]
+    if node_val:
+        scores = node_val["scores"]
+        weight_initialization_range = node_val["weight_initialization_range"]
+        positive_extent = weight_initialization_range["end"]
+        negative_extent = -weight_initialization_range["start"]
+        weight_initialization_range_skew = (positive_extent - negative_extent) / 2
+        generation = node_val["generation"]
+        for gen_index, score_list in enumerate(scores):
+            for score in score_list.values():
+                weight_initialization_range_skew_data[weight_initialization_range_skew][gen_index].append(score)
+
+# Prepare the heatmap data for weight_initialization_range_skew similar to minor_mutation_rate
+weight_initialization_range_skew_heatmap_data = np.full((len(weight_initialization_range_skew_data), max_generation + 1), np.nan)
+weight_initialization_range_skews = sorted(weight_initialization_range_skew_data.keys())
+
+for i, weight_initialization_range_skew in enumerate(weight_initialization_range_skews):
+    for generation in range(max_generation + 1):
+        scores = weight_initialization_range_skew_data[weight_initialization_range_skew][generation]
+        if scores:  # Check if there are scores for this generation
+            weight_initialization_range_skew_heatmap_data[i, generation] = np.mean(scores)
+
+# Creating a DataFrame for the weight initialization range skew heatmap
+df_weight_initialization_range_skew_heatmap = pd.DataFrame(
+    data=weight_initialization_range_skew_heatmap_data,
+    index=weight_initialization_range_skews,
+    columns=range(max_generation + 1)
+)
+
+# Analyze number of neurons correlation to score
+# We can get the number of neurons via node_val["nn_shapes"] which contains an array of maps
+# Each map has a key for the individual id and a value which is an array of integers representing the number of neurons in each layer
+# We can use the individual id to get the score from the scores array
+# We then generate a density map of the number of neurons vs the score
+neuron_number_score_data = defaultdict(lambda: defaultdict(list))
+
+for node in right_nodes:
+    node_val = node["val"]["node"]
+    if node_val:
+        scores = node_val["scores"]
+        nn_shapes = node_val["nn_shapes"]
+        # Both scores and nn_shapes are arrays indexed by generation; scores is one entry shorter than nn_shapes
+        for gen_index, score in enumerate(scores):
+            for individual_id, nn_shape in nn_shapes[gen_index].items():
+                neuron_number = sum(nn_shape)
+                # Check if score has a value for the individual id
+                if individual_id not in score:
+                    continue
+                neuron_number_score_data[neuron_number][gen_index].append(score[individual_id])
+
+# Prepare the density map data
+neuron_number_score_heatmap_data = np.full((len(neuron_number_score_data), max_generation + 1), np.nan)
+neuron_numbers = sorted(neuron_number_score_data.keys())
+
+for i, neuron_number in enumerate(neuron_numbers):
+    for generation in range(max_generation + 1):
+        scores = neuron_number_score_data[neuron_number][generation]
+        if scores:  # Check if there are scores for this generation
+            neuron_number_score_heatmap_data[i, generation] = np.mean(scores)
+
+# Creating a DataFrame for the neuron count heatmap
+df_neuron_number_score_heatmap = pd.DataFrame(
+    data=neuron_number_score_heatmap_data,
+    index=neuron_numbers,
+    columns=range(max_generation + 1)
+)
+
+# Analyze number of layers correlation to score
+nn_layers_score_data = defaultdict(lambda: defaultdict(list))
+
+for node in right_nodes:
+    node_val = node["val"]["node"]
+    if node_val:
+        scores = node_val["scores"]
+        nn_shapes = node_val["nn_shapes"]
+        # Both scores and nn_shapes are arrays indexed by generation; scores is one entry shorter than nn_shapes
+        for gen_index, score in enumerate(scores):
+            for individual_id, nn_shape in nn_shapes[gen_index].items():
+                layer_number = len(nn_shape)
+                # Check if score has a value for the individual id
+                if individual_id not in score:
+                    continue
+                nn_layers_score_data[layer_number][gen_index].append(score[individual_id])
+
+# Prepare the density map data
+nn_layers_score_heatmap_data = np.full((len(nn_layers_score_data), max_generation + 1), np.nan)
+nn_layers = sorted(nn_layers_score_data.keys())
+
+for i, nn_layer in enumerate(nn_layers):
+    for generation in range(max_generation + 1):
+        scores = nn_layers_score_data[nn_layer][generation]
+        if scores:  # Check if there are scores for this generation
+            nn_layers_score_heatmap_data[i, generation] = np.mean(scores)
+
+# Creating a DataFrame for the layer count heatmap
+df_nn_layers_score_heatmap = pd.DataFrame(
+    data=nn_layers_score_heatmap_data,
+    index=nn_layers,
+    columns=range(max_generation + 1)
+)
+
+# print("Format: ", custom_formatter(0.123498761234, 0))
+
+# Creating subplots
+fig, axs = plt.subplots(2, 2, figsize=(20, 14))  # Creates a 2x2 grid of subplots
+
+# Plotting the minor mutation rate heatmap
+sns.heatmap(df_heatmap.T, cmap='viridis', fmt=".4g", cbar_kws={'label': 'Mean Score'}, ax=axs[0, 0])
+# axs[0, 0].set_title('Minor Mutation Rate')
+axs[0, 0].set_xlabel('Minor Mutation Rate')
+axs[0, 0].set_ylabel('Generation')
+axs[0, 0].invert_yaxis()
+
+# Plotting the major mutation rate heatmap
+sns.heatmap(df_major_heatmap.T, cmap='viridis', fmt=".4g", cbar_kws={'label': 'Mean Score'}, ax=axs[0, 1])
+# axs[0, 1].set_title('Major Mutation Rate')
+axs[0, 1].set_xlabel('Major Mutation Rate')
+axs[0, 1].invert_yaxis()
+
+# Plotting the crossbreed_segments heatmap
+sns.heatmap(df_crossbreed_heatmap.T, cmap='viridis', fmt=".4g", cbar_kws={'label': 'Mean Score'}, ax=axs[1, 0])
+# axs[1, 0].set_title('Crossbreed Segments')
+axs[1, 0].set_xlabel('Crossbreed Segments')
+axs[1, 0].set_ylabel('Generation')
+axs[1, 0].invert_yaxis()
+
+# Plotting the mutation_weight_range heatmap
+sns.heatmap(df_mutation_weight_range_heatmap.T, cmap='viridis', fmt=".4g", cbar_kws={'label': 'Mean Score'}, ax=axs[1, 1])
+# axs[1, 1].set_title('Mutation Weight Range')
+axs[1, 1].set_xlabel('Mutation Weight Range')
+axs[1, 1].invert_yaxis()
+
+fig3, axs3 = plt.subplots(1, 2, figsize=(20, 14))  # Creates a 1x2 grid of subplots
+
+# Plotting the weight_initialization_range heatmap
+sns.heatmap(df_weight_initialization_range_heatmap.T, cmap='viridis', fmt=".4g", cbar_kws={'label': 'Mean Score'}, ax=axs3[0])
+# axs3[0].set_title('Weight Initialization Range')
+axs3[0].set_xlabel('Weight Initialization Range')
+axs3[0].set_ylabel('Generation')
+axs3[0].invert_yaxis()
+
+# Plotting the weight_initialization_range_skew heatmap
+sns.heatmap(df_weight_initialization_range_skew_heatmap.T, cmap='viridis', fmt=".4g", cbar_kws={'label': 'Mean Score'}, ax=axs3[1])
+# axs3[1].set_title('Weight Initialization Range Skew')
+axs3[1].set_xlabel('Weight Initialization Range Skew')
+axs3[1].set_ylabel('Generation')
+axs3[1].invert_yaxis()
+
+# Creating a new window for the network size heatmaps
+fig2, axs2 = plt.subplots(2, 1, figsize=(20, 14))  # Creates a 2x1 grid of subplots
+
+# Plotting the neuron number vs score heatmap
+sns.heatmap(df_neuron_number_score_heatmap.T, cmap='viridis', fmt=".4g", cbar_kws={'label': 'Mean Score'}, ax=axs2[1])
+# axs2[1].set_title('Neuron Number vs. Score')
+axs2[1].set_xlabel('Neuron Number')
+axs2[1].set_ylabel('Generation')
+axs2[1].invert_yaxis()
+
+# Plotting the number of layers vs score heatmap
+sns.heatmap(df_nn_layers_score_heatmap.T, cmap='viridis', fmt=".4g", cbar_kws={'label': 'Mean Score'}, ax=axs2[0])
+# axs2[0].set_title('Number of Layers vs.
Score') +axs2[0].set_xlabel('Number of Layers') +axs2[0].set_ylabel('Generation') +axs2[0].invert_yaxis() + +# Display the plot +plt.tight_layout() # Adjusts the subplots to fit into the figure area. +plt.show() \ No newline at end of file diff --git a/sbcl_spike/main.lisp b/sbcl_spike/main.lisp deleted file mode 100644 index cf6e008..0000000 --- a/sbcl_spike/main.lisp +++ /dev/null @@ -1,12 +0,0 @@ -;; Define a type that contains a population size and a population cutoff -(defclass simulation-node () ((population-size :initarg :population-size :accessor population-size) - (population-cutoff :initarg :population-cutoff :accessor population-cutoff) - (population :initform () :accessor population))) - -;; Define a method that initializes population-size number of children in a population each with a random value -(defmethod initialize-instance :after ((node simulation-node) &key) - (setf (population node) (make-list (population-size node) :initial-element (random 100)))) - -(let ((node (make-instance 'simulation-node :population-size 100 :population-cutoff 10))) - (print (population-size node)) - (population node)) \ No newline at end of file diff --git a/visualize_networks.py b/visualize_networks.py new file mode 100644 index 0000000..280011f --- /dev/null +++ b/visualize_networks.py @@ -0,0 +1,118 @@ +import matplotlib.pyplot as plt +import networkx as nx +import subprocess +import tkinter as tk +from tkinter import filedialog + +def select_file(): + root = tk.Tk() + root.withdraw() # Hide the main window + file_path = filedialog.askopenfilename( + initialdir="/", # Set the initial directory to search for files + title="Select file", + filetypes=(("Net files", "*.net"), ("All files", "*.*")) + ) + return file_path + +def get_fann_data(network_file): + # Adjust the path to the Rust executable as needed + result = subprocess.run(['./extract_fann_data/target/debug/extract_fann_data.exe', network_file], capture_output=True, text=True) + if result.returncode != 0: + print("Error:", result.stderr) + return None, None + + layer_sizes = [] + connections = [] + parsing_connections = False + + for line in result.stdout.splitlines(): + if line.startswith("Layers:"): + continue + elif line.startswith("Connections:"): + parsing_connections = True + continue + + if parsing_connections: + from_neuron, to_neuron, weight = map(float, line.split()) + connections.append((int(from_neuron), int(to_neuron), weight)) + else: + layer_size, bias_count = map(int, line.split()) + layer_sizes.append((layer_size, bias_count)) + + return layer_sizes, connections + +def visualize_fann_network(network_file): + # Get network data + layer_sizes, connections = get_fann_data(network_file) + if layer_sizes is None or connections is None: + return # Error handling in get_fann_data should provide error output + + # Create a directed graph + G = nx.DiGraph() + + # Positions dictionary to hold the position of each neuron + pos = {} + node_count = 0 + x_spacing = 1.0 + y_spacing = 1.0 + + # Calculate the maximum layer size for proper spacing + max_layer_size = max(size for size, bias in layer_sizes) + + # Build nodes and position them layer by layer from left to right + for layer_index, (layer_size, bias_count) in enumerate(layer_sizes): + y_positions = list(range(-layer_size-bias_count+1, 1, 1)) # Center-align vertically + y_positions = [y * (max_layer_size / (layer_size + bias_count)) * y_spacing for y in y_positions] # Adjust spacing + for neuron_index in range(layer_size + bias_count): # Include bias neurons + node_label = 
f"L{layer_index}N{neuron_index}" + G.add_node(node_count, label=node_label) + pos[node_count] = (layer_index * x_spacing, y_positions[neuron_index % len(y_positions)]) + node_count += 1 + + # Add connections to the graph + for from_neuron, to_neuron, weight in connections: + G.add_edge(from_neuron, to_neuron, weight=weight) + + max_weight = max(abs(weight) for _, _, weight in connections) + print(f"Max weight: {max_weight}") + + # Draw nodes + nx.draw_networkx_nodes(G, pos, node_color='skyblue', node_size=200) + nx.draw_networkx_labels(G, pos, font_size=7) + + # Custom function for edge properties + def adjust_properties(weight): + # if weight > 0: + # print("Weight:", weight) + color = 'green' if weight > 0 else 'red' + alpha = min((abs(weight) / max_weight) ** 3, 1) + # print(f"Color: {color}, Alpha: {alpha}") + return color, alpha + + # Draw edges with custom properties + for u, v, d in G.edges(data=True): + color, alpha = adjust_properties(d['weight']) + nx.draw_networkx_edges(G, pos, edgelist=[(u, v)], edge_color=color, alpha=alpha, width=1.5, arrows=False) + + # Show plot + plt.title('FANN Network Visualization') + plt.axis('off') # Turn off the axis + plt.show() + +# Path to the FANN network file +fann_path = 'F:\\\\vandomej\\Projects\\dootcamp-AI-Simulation\\Simulations\\fighter_nn_4f2be613-ab26-4384-9a65-450e043984ea\\6\\4f2be613-ab26-4384-9a65-450e043984ea_fighter_nn_0.net' +# fann_path = "F:\\\\vandomej\\Projects\\dootcamp-AI-Simulation\\Simulations\\fighter_nn_fc294503-7b2a-40f8-be59-ccc486eb3f79\\0\\fc294503-7b2a-40f8-be59-ccc486eb3f79_fighter_nn_0.net" +# fann_path = 'F:\\\\vandomej\\Projects\\dootcamp-AI-Simulation\\Simulations\\fighter_nn_99c30a7f-40ab-4faf-b16a-b44703fdb6cd\\0\\99c30a7f-40ab-4faf-b16a-b44703fdb6cd_fighter_nn_0.net' +# Has a 4 layer network +# # Generation 1 +# fann_path = "F:\\\\vandomej\\Projects\\dootcamp-AI-Simulation\\Simulations\\fighter_nn_16dfa1b4-03c7-45a6-84b4-22fe3c8e2d98\\1\\16dfa1b4-03c7-45a6-84b4-22fe3c8e2d98_fighter_nn_0.net" +# # Generation 5 +# fann_path = "F:\\\\vandomej\\Projects\\dootcamp-AI-Simulation\\Simulations\\fighter_nn_16dfa1b4-03c7-45a6-84b4-22fe3c8e2d98\\5\\16dfa1b4-03c7-45a6-84b4-22fe3c8e2d98_fighter_nn_0.net" +# # Generation 10 +# fann_path = "F:\\\\vandomej\\Projects\\dootcamp-AI-Simulation\\Simulations\\fighter_nn_16dfa1b4-03c7-45a6-84b4-22fe3c8e2d98\\10\\16dfa1b4-03c7-45a6-84b4-22fe3c8e2d98_fighter_nn_0.net" +# # Generation 20 +# fann_path = "F:\\\\vandomej\\Projects\\dootcamp-AI-Simulation\\Simulations\\fighter_nn_16dfa1b4-03c7-45a6-84b4-22fe3c8e2d98\\20\\16dfa1b4-03c7-45a6-84b4-22fe3c8e2d98_fighter_nn_0.net" +# # Generation 32 +# fann_path = "F:\\\\vandomej\\Projects\\dootcamp-AI-Simulation\\Simulations\\fighter_nn_16dfa1b4-03c7-45a6-84b4-22fe3c8e2d98\\32\\16dfa1b4-03c7-45a6-84b4-22fe3c8e2d98_fighter_nn_0.net" +fann_path = select_file() +visualize_fann_network(fann_path) \ No newline at end of file diff --git a/visualize_simulation_tree.py b/visualize_simulation_tree.py new file mode 100644 index 0000000..7d91343 --- /dev/null +++ b/visualize_simulation_tree.py @@ -0,0 +1,104 @@ +# Re-importing necessary libraries +import json +import matplotlib.pyplot as plt +import networkx as nx +import random + +def hierarchy_pos(G, root=None, width=1., vert_gap=0.2, vert_loc=0, xcenter=0.5): + if not nx.is_tree(G): + raise TypeError('cannot use hierarchy_pos on a graph that is not a tree') + + if root is None: + if isinstance(G, nx.DiGraph): + root = next(iter(nx.topological_sort(G))) + else: + root = random.choice(list(G.nodes)) + 
+    def _hierarchy_pos(G, root, width=2., vert_gap=0.2, vert_loc=0, xcenter=0.5, pos=None, parent=None):
+        if pos is None:
+            pos = {root: (xcenter, vert_loc)}
+        else:
+            pos[root] = (xcenter, vert_loc)
+        children = list(G.successors(root))  # Use successors to get children for DiGraph
+        if not isinstance(G, nx.DiGraph):
+            if parent is not None:
+                children.remove(parent)
+        if len(children) != 0:
+            dx = width / len(children)
+            nextx = xcenter - width / 2 - dx / 2
+            for child in children:
+                nextx += dx
+                pos = _hierarchy_pos(G, child, width=dx*2.0, vert_gap=vert_gap,
+                                     vert_loc=vert_loc - vert_gap, xcenter=nextx,
+                                     pos=pos, parent=root)
+        return pos
+
+    return _hierarchy_pos(G, root, width, vert_gap, vert_loc, xcenter)
+
+# Simplified JSON data for demonstration
+with open('gemla/round4.json', 'r') as file:
+    simplified_json_data = json.load(file)
+
+# Function to traverse the tree and create a graph
+def traverse(node, graph, parent=None):
+    if node is None:
+        return
+
+    node_id = node["val"]["id"]
+    if "node" in node["val"] and node["val"]["node"]:
+        scores = node["val"]["node"]["scores"]
+        generations = node["val"]["node"]["generation"]
+        population_size = node["val"]["node"]["population_size"]
+        # Prepare to track the highest score across all generations and the corresponding individual
+        overall_max_score = float('-inf')
+        overall_max_score_individual = None
+        overall_max_score_gen = None
+
+        for gen, gen_scores in enumerate(scores):
+            if gen_scores:  # Ensure the dictionary is not empty
+                # Find the max score and the individual for this generation
+                max_score_for_gen = max(gen_scores.values())
+                individual_with_max_score_for_gen = max(gen_scores, key=gen_scores.get)
+
+                # Keep only the best score seen so far across generations
+                if max_score_for_gen > overall_max_score:
+                    overall_max_score = max_score_for_gen
+                    overall_max_score_individual = individual_with_max_score_for_gen
+                    overall_max_score_gen = gen
+
+        # print debug statement
+        # print(f"Node {node_id}: Max score: {overall_max_score:.6f} (Individual {overall_max_score_individual} in Gen {overall_max_score_gen})")
+        # print(f"Left: {node.get('left')}, Right: {node.get('right')}")
+        label = f"{node_id}\nGenerations: {generations}, Population: {population_size}\nMax score: {overall_max_score:.6f} (Individual {overall_max_score_individual} in Gen {overall_max_score_gen + 1 if overall_max_score_gen is not None else 'N/A'})"
+    else:
+        label = node_id
+
+    graph.add_node(node_id, label=label)
+    if parent:
+        graph.add_edge(parent, node_id)
+
+    traverse(node.get("left"), graph, parent=node_id)
+    traverse(node.get("right"), graph, parent=node_id)
+
+
+# Create a directed graph
+G = nx.DiGraph()
+
+# Populate the graph
+traverse(simplified_json_data[0], G)
+
+# Find the root node (a node with no incoming edges)
+root_candidates = [node for node, indeg in G.in_degree() if indeg == 0]
+
+if root_candidates:
+    root_node = root_candidates[0]  # Assuming there's only one root candidate
+else:
+    root_node = None  # This should ideally never happen in a properly structured tree
+
+# Use the determined root node for hierarchy_pos
+if root_node is not None:
+    pos = hierarchy_pos(G, root=root_node)
+    labels = nx.get_node_attributes(G, 'label')
+    nx.draw(G, pos, labels=labels, with_labels=True, arrows=True)
+    plt.show()
+else:
+    print("No root node found. Cannot draw the tree.")
\ No newline at end of file
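Every hyperparameter heatmap in analyze_data.py follows the same three-step pattern: bucket scores by (parameter value, generation) in a nested defaultdict, collapse each bucket to a mean inside a NaN-filled grid, and hand the transposed DataFrame to seaborn. The sketch below reproduces that pipeline on synthetic data so it can be run standalone; the names (param_scores, params) and the synthetic score distribution are illustrative and not part of the committed script.

# Minimal, self-contained sketch of the heatmap pipeline; synthetic data only.
from collections import defaultdict
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

# parameter value -> generation -> list of scores (stand-in for the real data)
param_scores = defaultdict(lambda: defaultdict(list))
rng = np.random.default_rng(0)
max_generation = 9
for param in (0.01, 0.05, 0.1):
    for gen in range(max_generation + 1):
        # Synthetic scores that drift upward with generation and vary by parameter
        param_scores[param][gen].extend(rng.normal(gen * param * 10, 1.0, size=5))

# Collapse each (parameter, generation) bucket to a mean; NaN where no data exists
params = sorted(param_scores.keys())
heatmap = np.full((len(params), max_generation + 1), np.nan)
for i, param in enumerate(params):
    for gen in range(max_generation + 1):
        scores = param_scores[param][gen]
        if scores:
            heatmap[i, gen] = np.mean(scores)

df = pd.DataFrame(heatmap, index=params, columns=range(max_generation + 1))

# Transpose so parameter values run along the x-axis and generations along the y-axis
ax = sns.heatmap(df.T, cmap='viridis', cbar_kws={'label': 'Mean Score'})
ax.set_xlabel('Parameter Value')
ax.set_ylabel('Generation')
ax.invert_yaxis()
plt.show()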
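get_fann_data in visualize_networks.py relies on an implicit text protocol from the extract_fann_data helper: a "Layers:" header followed by "layer_size bias_count" pairs, then a "Connections:" header followed by "from_neuron to_neuron weight" triples. The exact output of the Rust binary is not shown in this diff, so the sample below is inferred from the parsing code and is illustrative only.

# Illustrative only: sample stdout inferred from get_fann_data's parsing logic.
sample_stdout = """Layers:
3 1
4 1
2 0
Connections:
0 4 0.5
1 5 -1.25
2 6 0.75
"""

layer_sizes, connections = [], []
parsing_connections = False
for line in sample_stdout.splitlines():
    if line.startswith("Layers:"):
        continue
    elif line.startswith("Connections:"):
        parsing_connections = True
        continue
    if parsing_connections:
        # "from_neuron to_neuron weight" triples
        from_neuron, to_neuron, weight = map(float, line.split())
        connections.append((int(from_neuron), int(to_neuron), weight))
    else:
        # "layer_size bias_count" pairs
        layer_size, bias_count = map(int, line.split())
        layer_sizes.append((layer_size, bias_count))

print(layer_sizes)   # [(3, 1), (4, 1), (2, 0)]
print(connections)   # [(0, 4, 0.5), (1, 5, -1.25), (2, 6, 0.75)]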
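hierarchy_pos from visualize_simulation_tree.py can be exercised on a small hand-built tree to sanity-check the layout before pointing the script at a real simulation file. A quick usage sketch, assuming hierarchy_pos is defined or imported in the same session:

# Lay out a tiny hand-built binary tree with hierarchy_pos and draw it.
import matplotlib.pyplot as plt
import networkx as nx

T = nx.DiGraph()
T.add_edges_from([("root", "left"), ("root", "right"),
                  ("left", "left.left"), ("left", "left.right")])

pos = hierarchy_pos(T, root="root")  # root at the top, children spread below
nx.draw(T, pos, with_labels=True, arrows=True, node_color='skyblue')
plt.show()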