Compare commits


No commits in common. "6725ab3feb529a00c59c085b7e021e34db3713fe" and "e8d373d4f991d47f167372cca2a04675b40a24e1" have entirely different histories.

32 changed files with 478 additions and 5279 deletions

.gitignore vendored

@@ -14,7 +14,3 @@ settings.json
 .DS_Store
 .vscode/alive
-
-# Added by cargo
-/target

.vscode/launch.json vendored

@@ -10,55 +10,7 @@
       "name": "Debug",
       "program": "${workspaceFolder}/gemla/target/debug/gemla.exe",
       "args": ["./gemla/temp/"],
-      "cwd": "${workspaceFolder}/gemla"
+      "cwd": "${workspaceFolder}"
-    },
-    {
-      "type": "lldb",
-      "request": "launch",
-      "name": "Debug Rust Tests",
-      "cargo": {
-        "args": [
-          "test",
-          "--manifest-path", "${workspaceFolder}/gemla/Cargo.toml",
-          "--no-run", // Compiles the tests without running them
-          "--package=gemla", // Specify your package name if necessary
-          "--bin=bin"
-        ],
-        "filter": { }
-      },
-      "args": [],
-    },
-    {
-      "type": "lldb",
-      "request": "launch",
-      "name": "Debug gemla Lib Tests",
-      "cargo": {
-        "args": [
-          "test",
-          "--manifest-path", "${workspaceFolder}/gemla/Cargo.toml",
-          "--no-run", // Compiles the tests without running them
-          "--package=gemla", // Specify your package name if necessary
-          "--lib"
-        ],
-        "filter": { }
-      },
-      "args": [],
-    },
-    {
-      "type": "lldb",
-      "request": "launch",
-      "name": "Debug Rust FileLinked Tests",
-      "cargo": {
-        "args": [
-          "test",
-          "--manifest-path", "${workspaceFolder}/file_linked/Cargo.toml",
-          "--no-run", // Compiles the tests without running them
-          "--package=file_linked", // Specify your package name if necessary
-          "--lib"
-        ],
-        "filter": { }
-      },
-      "args": [],
-    }
     }
   ]
 }


@@ -1,171 +0,0 @@
# Re-importing necessary libraries
import json
import matplotlib.pyplot as plt
from collections import defaultdict
import numpy as np

# Simplified JSON data for demonstration
with open('gemla/round4.json', 'r') as file:
    simplified_json_data = json.load(file)

target_node_id = '523f8250-3101-4586-90a1-127ffa6d73d9'

# Traverse the tree collecting the chain of left children
def traverse_left_nodes(node):
    if node is None:
        return []
    left_node = node.get("left")
    if left_node is None:
        return [node]
    return [node] + traverse_left_nodes(left_node)

# Traverse the tree collecting right children along the left spine
def traverse_right_nodes(node):
    if node is None:
        return []
    right_node = node.get("right")
    left_node = node.get("left")
    if right_node is None and left_node is None:
        return []
    elif right_node and left_node:
        return [right_node] + traverse_right_nodes(left_node)
    return []

# Getting the left graph
left_nodes = traverse_left_nodes(simplified_json_data[0])
left_nodes.reverse()

# Properties available on the first node
node = left_nodes[0]
# print(node["val"].keys())

scores = []
for node in left_nodes:
    # print(f'Node ID: {node["val"]["id"]}')
    # print(f'Node scores length: {len(node["val"]["node"]["scores"])}')
    if node["val"]["node"]:
        node_scores = node["val"]["node"]["scores"]
        if node_scores:
            for score in node_scores:
                scores.append(score)

scores_values = [list(score_set.values()) for score_set in scores]

# Set up the figure for plotting on the same graph
fig, ax = plt.subplots(figsize=(10, 6))

# Generate a boxplot for each set of scores on the same graph
boxplots = ax.boxplot(scores_values, vert=False, patch_artist=True, labels=[f'Set {i+1}' for i in range(len(scores_values))])

ax.set_xscale('symlog', linthresh=1.0)

# Labeling
ax.set_xlabel('Scores - Main Line')
ax.set_ylabel('Score Sets')
ax.yaxis.grid(True)  # Add horizontal grid lines for clarity

# Set y-axis labels to be visible
ax.set_yticklabels([f'Set {i+1}' for i in range(len(scores_values))])

# Getting most recent right graph
right_nodes = traverse_right_nodes(simplified_json_data[0])

if len(right_nodes) != 0:
    target_node_id = None
    target_node = None
    if target_node_id:
        for node in right_nodes:
            if node["val"]["id"] == target_node_id:
                target_node = node
                break
    else:
        target_node = right_nodes[0]

    scores = target_node["val"]["node"]["scores"]
    scores_values = [list(score_set.values()) for score_set in scores]

    # Set up the figure for plotting on the same graph
    fig, ax = plt.subplots(figsize=(10, 6))

    # Generate a boxplot for each set of scores on the same graph
    boxplots = ax.boxplot(scores_values, vert=False, patch_artist=True, labels=[f'Set {i+1}' for i in range(len(scores_values))])

    ax.set_xscale('symlog', linthresh=1.0)

    # Labeling
    ax.set_xlabel(f"Scores: {target_node['val']['id']}")
    ax.set_ylabel('Score Sets')
    ax.yaxis.grid(True)  # Add horizontal grid lines for clarity

    # Set y-axis labels to be visible
    ax.set_yticklabels([f'Set {i+1}' for i in range(len(scores_values))])

# Find the highest scoring sets combining all scores and generations
scores = []
for node in left_nodes:
    if node["val"]["node"]:
        node_scores = node["val"]["node"]["scores"]
        translated_node_scores = []
        if node_scores:
            for i in range(len(node_scores)):
                for (individual, score) in node_scores[i].items():
                    translated_node_scores.append((node["val"]["id"], i, score))
            scores.append(translated_node_scores)

# Add scores from the right nodes
if len(right_nodes) != 0:
    for node in right_nodes:
        if node["val"]["node"]:
            node_scores = node["val"]["node"]["scores"]
            translated_node_scores = []
            if node_scores:
                for i in range(len(node_scores)):
                    for (individual, score) in node_scores[i].items():
                        translated_node_scores.append((node["val"]["id"], i, score))
                scores.append(translated_node_scores)

# Organize scores by individual and then by generation
individual_generation_scores = defaultdict(lambda: defaultdict(list))
for sublist in scores:
    for id, generation, score in sublist:
        individual_generation_scores[id][generation].append(score)

# Calculate Q3 for each individual's generation
individual_generation_q3 = {}
for id, generations in individual_generation_scores.items():
    for gen, scores in generations.items():
        individual_generation_q3[(id, gen)] = np.percentile(scores, 75)

# Sort by Q3 value, highest first, and select the top 40
top_individual_generations = sorted(individual_generation_q3, key=individual_generation_q3.get, reverse=True)[:40]

# Prepare scores for the top individual generations for plotting
top_scores = [individual_generation_scores[id][gen] for id, gen in top_individual_generations]

# Adjust labels for clarity, indicating both the individual ID and generation
labels = [f'{id[:8]}... Gen {gen}' for id, gen in top_individual_generations]

# Generate box and whisker plots for the top individual generations
fig, ax = plt.subplots(figsize=(12, 10))
ax.boxplot(top_scores, vert=False, patch_artist=True, labels=labels)
ax.set_xscale('symlog', linthresh=1.0)
ax.set_xlabel('Scores')
ax.set_ylabel('Individual Generation')
ax.set_title('Top 40 Individual Generations by Q3 Value')
ax.yaxis.grid(True)  # Add horizontal grid lines for clarity

# Display the plot
plt.show()

carp_spike/.gitignore vendored Normal file

@@ -0,0 +1 @@
out/

carp_spike/main.carp Normal file

@@ -0,0 +1,10 @@
(use Random)

(Project.config "title" "gemla")

(deftype SimulationNode [population-size Int, population-cutoff Int])
;; (let [test (SimulationNode.init 10 3)]
;;   (do
;;     (SimulationNode.set-population-size test 20)
;;     (SimulationNode.population-size &test)
;;   ))


@@ -1,9 +0,0 @@
[package]
name = "extract_fann_data"
version = "0.1.0"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
fann = "0.1.8"


@@ -1,11 +0,0 @@
fn main() {
    // Replace this with the path to the directory containing `fann.lib`
    let lib_dir = "F://vandomej/Downloads/vcpkg/packages/fann_x64-windows/lib";
    println!("cargo:rustc-link-search=native={}", lib_dir);
    println!("cargo:rustc-link-lib=static=fann");

    // Use `dylib=fann` instead of `static=fann` if you're linking dynamically
    // If there are any additional directories where the compiler can find header files, you can specify them like this:
    // println!("cargo:include={}", path_to_include_directory);
}


@@ -1,38 +0,0 @@
extern crate fann;

use fann::Fann;

fn main() {
    let args: Vec<String> = std::env::args().collect();
    if args.len() < 2 {
        eprintln!("Usage: {} <network_file>", args[0]);
        std::process::exit(1);
    }
    let network_file = &args[1];

    match Fann::from_file(network_file) {
        Ok(ann) => {
            // Output layer sizes
            let layer_sizes = ann.get_layer_sizes();
            let bias_counts = ann.get_bias_counts();
            println!("Layers:");
            for (layer_size, bias_count) in layer_sizes.iter().zip(bias_counts.iter()) {
                println!("{} {}", layer_size, bias_count);
            }

            // Output connections
            println!("Connections:");
            let connections = ann.get_connections();
            for connection in connections {
                println!("{} {} {}", connection.from_neuron, connection.to_neuron, connection.weight);
            }
        }
        Err(err) => {
            eprintln!("Error loading network from file {}: {}", network_file, err);
            std::process::exit(1);
        }
    }
}


@@ -20,6 +20,3 @@ thiserror = "1.0"
 anyhow = "1.0"
 bincode = "1.3.3"
 log = "0.4.14"
-serde_json = "1.0.114"
-tokio = { version = "1.37.0", features = ["full"] }
-futures = "0.3.30"


@@ -1,5 +0,0 @@
#[derive(Debug)]
pub enum DataFormat {
    Bincode,
    Json,
}
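
The deleted `DataFormat` enum was the runtime switch between the two serde backends that file_linked supported. A minimal sketch of the dispatch it enabled (the `to_bytes` helper is illustrative, not part of the crate):

use serde::Serialize;

#[derive(Debug)]
pub enum DataFormat {
    Bincode,
    Json,
}

// Illustrative helper: serialize a value with whichever backend the format selects.
pub fn to_bytes<T: Serialize>(val: &T, format: &DataFormat) -> anyhow::Result<Vec<u8>> {
    Ok(match format {
        DataFormat::Bincode => bincode::serialize(val)?, // compact binary
        DataFormat::Json => serde_json::to_vec(val)?,    // human-readable JSON
    })
}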


@@ -1 +0,0 @@
pub mod data_format;


@@ -1,10 +1,8 @@
 //! A wrapper around an object that ties it to a physical file

-pub mod constants;
 pub mod error;

 use anyhow::{anyhow, Context};
-use constants::data_format::DataFormat;
 use error::Error;
 use log::info;
 use serde::{de::DeserializeOwned, Serialize};
@@ -12,10 +10,9 @@ use std::{
     fs::{copy, remove_file, File},
     io::{ErrorKind, Write},
     path::{Path, PathBuf},
-    sync::Arc,
-    thread::{self, JoinHandle},
+    thread,
+    thread::JoinHandle,
 };
-use tokio::sync::RwLock;

 /// A wrapper around an object `T` that ties the object to a physical file
 #[derive(Debug)]
@@ -23,11 +20,10 @@ pub struct FileLinked<T>
 where
     T: Serialize,
 {
-    val: Arc<RwLock<T>>,
+    val: T,
     path: PathBuf,
     temp_file_path: PathBuf,
     file_thread: Option<JoinHandle<()>>,
-    data_format: DataFormat,
 }

 impl<T> Drop for FileLinked<T>
@@ -52,12 +48,10 @@ where
     /// # Examples
     /// ```
     /// # use file_linked::*;
-    /// # use file_linked::constants::data_format::DataFormat;
     /// # use serde::{Deserialize, Serialize};
     /// # use std::fmt;
     /// # use std::string::ToString;
     /// # use std::path::PathBuf;
-    /// # use tokio;
     /// #
     /// # #[derive(Deserialize, Serialize)]
     /// # struct Test {
@@ -66,30 +60,27 @@ where
     /// #    pub c: f64
     /// # }
     /// #
-    /// # #[tokio::main]
-    /// # async fn main() {
+    /// # fn main() {
     /// let test = Test {
     ///     a: 1,
     ///     b: String::from("two"),
     ///     c: 3.0
     /// };
     ///
-    /// let linked_test = FileLinked::new(test, &PathBuf::from("./temp"), DataFormat::Json).await
+    /// let linked_test = FileLinked::new(test, &PathBuf::from("./temp"))
     ///     .expect("Unable to create file linked object");
     ///
-    /// let readonly = linked_test.readonly();
-    /// let readonly_ref = readonly.read().await;
-    /// assert_eq!(readonly_ref.a, 1);
-    /// assert_eq!(readonly_ref.b, String::from("two"));
-    /// assert_eq!(readonly_ref.c, 3.0);
+    /// assert_eq!(linked_test.readonly().a, 1);
+    /// assert_eq!(linked_test.readonly().b, String::from("two"));
+    /// assert_eq!(linked_test.readonly().c, 3.0);
     /// #
     /// # drop(linked_test);
     /// #
     /// # std::fs::remove_file("./temp").expect("Unable to remove file");
     /// # }
     /// ```
-    pub fn readonly(&self) -> Arc<RwLock<T>> {
-        self.val.clone()
+    pub fn readonly(&self) -> &T {
+        &self.val
     }

     /// Creates a new [`FileLinked`] object of type `T` stored to the file given by `path`.
@@ -97,12 +88,10 @@ where
     /// # Examples
     /// ```
     /// # use file_linked::*;
-    /// # use file_linked::constants::data_format::DataFormat;
     /// # use serde::{Deserialize, Serialize};
     /// # use std::fmt;
     /// # use std::string::ToString;
     /// # use std::path::PathBuf;
-    /// # use tokio;
     /// #
     /// #[derive(Deserialize, Serialize)]
     /// struct Test {
@@ -111,29 +100,26 @@ where
     ///     pub c: f64
     /// }
     ///
-    /// #[tokio::main]
-    /// # async fn main() {
+    /// # fn main() {
     /// let test = Test {
     ///     a: 1,
     ///     b: String::from("two"),
     ///     c: 3.0
     /// };
     ///
-    /// let linked_test = FileLinked::new(test, &PathBuf::from("./temp"), DataFormat::Json).await
+    /// let linked_test = FileLinked::new(test, &PathBuf::from("./temp"))
     ///     .expect("Unable to create file linked object");
     ///
-    /// let readonly = linked_test.readonly();
-    /// let readonly_ref = readonly.read().await;
-    /// assert_eq!(readonly_ref.a, 1);
-    /// assert_eq!(readonly_ref.b, String::from("two"));
-    /// assert_eq!(readonly_ref.c, 3.0);
+    /// assert_eq!(linked_test.readonly().a, 1);
+    /// assert_eq!(linked_test.readonly().b, String::from("two"));
+    /// assert_eq!(linked_test.readonly().c, 3.0);
     /// #
     /// # drop(linked_test);
     /// #
     /// # std::fs::remove_file("./temp").expect("Unable to remove file");
     /// # }
     /// ```
-    pub async fn new(val: T, path: &Path, data_format: DataFormat) -> Result<FileLinked<T>, Error> {
+    pub fn new(val: T, path: &Path) -> Result<FileLinked<T>, Error> {
         let mut temp_file_path = path.to_path_buf();
         temp_file_path.set_file_name(format!(
             ".temp{}",
@@ -144,28 +130,21 @@ where
         ));

         let mut result = FileLinked {
-            val: Arc::new(RwLock::new(val)),
+            val,
             path: path.to_path_buf(),
             temp_file_path,
             file_thread: None,
-            data_format,
         };

-        result.write_data().await?;
+        result.write_data()?;

         Ok(result)
     }

-    async fn write_data(&mut self) -> Result<(), Error> {
+    fn write_data(&mut self) -> Result<(), Error> {
         let thread_path = self.path.clone();
         let thread_temp_path = self.temp_file_path.clone();
-        let val = self.val.read().await;
-
-        let thread_val = match self.data_format {
-            DataFormat::Bincode => bincode::serialize(&*val)
-                .with_context(|| "Unable to serialize object into bincode".to_string())?,
-            DataFormat::Json => serde_json::to_vec(&*val)
-                .with_context(|| "Unable to serialize object into JSON".to_string())?,
-        };
+        let thread_val = bincode::serialize(&self.val)
+            .with_context(|| "Unable to serialize object into bincode".to_string())?;

         if let Some(file_thread) = self.file_thread.take() {
             file_thread
@@ -211,12 +190,10 @@ where
     /// ```
     /// # use file_linked::*;
     /// # use file_linked::error::Error;
-    /// # use file_linked::constants::data_format::DataFormat;
     /// # use serde::{Deserialize, Serialize};
     /// # use std::fmt;
     /// # use std::string::ToString;
     /// # use std::path::PathBuf;
-    /// # use tokio;
     /// #
     /// # #[derive(Deserialize, Serialize)]
     /// # struct Test {
@@ -225,28 +202,21 @@ where
     /// #    pub c: f64
     /// # }
     /// #
-    /// # #[tokio::main]
-    /// # async fn main() -> Result<(), Error> {
+    /// # fn main() -> Result<(), Error> {
     /// let test = Test {
     ///     a: 1,
     ///     b: String::from(""),
     ///     c: 0.0
     /// };
     ///
-    /// let mut linked_test = FileLinked::new(test, &PathBuf::from("./temp"), DataFormat::Bincode).await
+    /// let mut linked_test = FileLinked::new(test, &PathBuf::from("./temp"))
     ///     .expect("Unable to create file linked object");
     ///
-    /// {
-    ///     let readonly = linked_test.readonly();
-    ///     let readonly_ref = readonly.read().await;
-    ///     assert_eq!(readonly_ref.a, 1);
-    /// }
+    /// assert_eq!(linked_test.readonly().a, 1);
     ///
-    /// linked_test.mutate(|t| t.a = 2).await?;
+    /// linked_test.mutate(|t| t.a = 2)?;
    ///
-    /// let readonly = linked_test.readonly();
-    /// let readonly_ref = readonly.read().await;
-    /// assert_eq!(readonly_ref.a, 2);
+    /// assert_eq!(linked_test.readonly().a, 2);
     /// #
     /// # drop(linked_test);
     /// #
@@ -255,15 +225,10 @@ where
     /// # Ok(())
     /// # }
     /// ```
-    pub async fn mutate<U, F: FnOnce(&mut T) -> U>(&mut self, op: F) -> Result<U, Error> {
-        let val_clone = self.val.clone(); // Arc<RwLock<T>>
-        let mut val = val_clone.write().await; // RwLockWriteGuard<T>
-
-        let result = op(&mut val);
-        drop(val);
-
-        self.write_data().await?;
+    pub fn mutate<U, F: FnOnce(&mut T) -> U>(&mut self, op: F) -> Result<U, Error> {
+        let result = op(&mut self.val);
+
+        self.write_data()?;

         Ok(result)
     }
@@ -274,12 +239,10 @@ where
     /// ```
     /// # use file_linked::*;
     /// # use file_linked::error::Error;
-    /// # use file_linked::constants::data_format::DataFormat;
     /// # use serde::{Deserialize, Serialize};
     /// # use std::fmt;
     /// # use std::string::ToString;
     /// # use std::path::PathBuf;
-    /// # use tokio;
     /// #
     /// # #[derive(Deserialize, Serialize)]
     /// # struct Test {
@@ -288,30 +251,25 @@ where
     /// #    pub c: f64
     /// # }
     /// #
-    /// # #[tokio::main]
-    /// # async fn main() -> Result<(), Error> {
+    /// # fn main() -> Result<(), Error> {
     /// let test = Test {
     ///     a: 1,
     ///     b: String::from(""),
     ///     c: 0.0
     /// };
     ///
-    /// let mut linked_test = FileLinked::new(test, &PathBuf::from("./temp"), DataFormat::Bincode).await
+    /// let mut linked_test = FileLinked::new(test, &PathBuf::from("./temp"))
     ///     .expect("Unable to create file linked object");
     ///
-    /// let readonly = linked_test.readonly();
-    /// let readonly_ref = readonly.read().await;
-    /// assert_eq!(readonly_ref.a, 1);
+    /// assert_eq!(linked_test.readonly().a, 1);
     ///
     /// linked_test.replace(Test {
     ///     a: 2,
     ///     b: String::from(""),
     ///     c: 0.0
-    /// }).await?;
+    /// })?;
     ///
-    /// let readonly = linked_test.readonly();
-    /// let readonly_ref = readonly.read().await;
-    /// assert_eq!(readonly_ref.a, 2);
+    /// assert_eq!(linked_test.readonly().a, 2);
     /// #
     /// # drop(linked_test);
     /// #
@@ -320,30 +278,10 @@ where
     /// # Ok(())
     /// # }
     /// ```
-    pub async fn replace(&mut self, val: T) -> Result<(), Error> {
-        self.val = Arc::new(RwLock::new(val));
-        self.write_data().await
-    }
-}
-
-impl<T> FileLinked<T>
-where
-    T: Serialize + DeserializeOwned + Send + 'static,
-{
-    /// Asynchronously modifies the data contained in a `FileLinked` object using an async callback `op`.
-    pub async fn mutate_async<F, Fut, U>(&mut self, op: F) -> Result<U, Error>
-    where
-        F: FnOnce(Arc<RwLock<T>>) -> Fut,
-        Fut: std::future::Future<Output = U> + Send,
-        U: Send,
-    {
-        let val_clone = self.val.clone();
-        let result = op(val_clone).await;
-        self.write_data().await?;
-        Ok(result)
+    pub fn replace(&mut self, val: T) -> Result<(), Error> {
+        self.val = val;
+        self.write_data()
     }
 }
@@ -357,7 +295,6 @@ where
     /// ```
     /// # use file_linked::*;
     /// # use file_linked::error::Error;
-    /// # use file_linked::constants::data_format::DataFormat;
     /// # use serde::{Deserialize, Serialize};
     /// # use std::fmt;
     /// # use std::string::ToString;
@@ -365,7 +302,6 @@ where
     /// # use std::fs::OpenOptions;
     /// # use std::io::Write;
     /// # use std::path::PathBuf;
-    /// # use tokio;
     /// #
     /// # #[derive(Deserialize, Serialize)]
     /// # struct Test {
@@ -374,8 +310,7 @@ where
     /// #    pub c: f64
     /// # }
     /// #
-    /// # #[tokio::main]
-    /// # async fn main() -> Result<(), Error> {
+    /// # fn main() -> Result<(), Error> {
     /// let test = Test {
     ///     a: 1,
     ///     b: String::from("2"),
@@ -392,14 +327,12 @@ where
     ///
     /// bincode::serialize_into(file, &test).expect("Unable to serialize object");
     ///
-    /// let mut linked_test = FileLinked::<Test>::from_file(&path, DataFormat::Bincode)
+    /// let mut linked_test = FileLinked::<Test>::from_file(&path)
     ///     .expect("Unable to create file linked object");
     ///
-    /// let readonly = linked_test.readonly();
-    /// let readonly_ref = readonly.read().await;
-    /// assert_eq!(readonly_ref.a, test.a);
-    /// assert_eq!(readonly_ref.b, test.b);
-    /// assert_eq!(readonly_ref.c, test.c);
+    /// assert_eq!(linked_test.readonly().a, test.a);
+    /// assert_eq!(linked_test.readonly().b, test.b);
+    /// assert_eq!(linked_test.readonly().c, test.c);
     /// #
     /// # drop(linked_test);
     /// #
@@ -408,7 +341,7 @@ where
     /// # Ok(())
     /// # }
     /// ```
-    pub fn from_file(path: &Path, data_format: DataFormat) -> Result<FileLinked<T>, Error> {
+    pub fn from_file(path: &Path) -> Result<FileLinked<T>, Error> {
         let mut temp_file_path = path.to_path_buf();
         temp_file_path.set_file_name(format!(
             ".temp{}",
@@ -418,22 +351,16 @@ where
             .ok_or_else(|| anyhow!("Unable to get filename for tempfile {}", path.display()))?
         ));

-        match File::open(path)
-            .map_err(Error::from)
-            .and_then(|file| match data_format {
-                DataFormat::Bincode => bincode::deserialize_from::<File, T>(file)
-                    .with_context(|| format!("Unable to deserialize file {}", path.display()))
-                    .map_err(Error::from),
-                DataFormat::Json => serde_json::from_reader(file)
-                    .with_context(|| format!("Unable to deserialize file {}", path.display()))
-                    .map_err(Error::from),
-            }) {
+        match File::open(path).map_err(Error::from).and_then(|file| {
+            bincode::deserialize_from::<File, T>(file)
+                .with_context(|| format!("Unable to deserialize file {}", path.display()))
+                .map_err(Error::from)
+        }) {
             Ok(val) => Ok(FileLinked {
-                val: Arc::new(RwLock::new(val)),
+                val,
                 path: path.to_path_buf(),
                 temp_file_path,
                 file_thread: None,
-                data_format,
             }),
             Err(err) => {
                 info!(
@@ -443,43 +370,30 @@ where
                 );

                 // Try to use temp file instead and see if that file exists and is serializable
-                let val = FileLinked::from_temp_file(&temp_file_path, path, &data_format)
+                let val = FileLinked::from_temp_file(&temp_file_path, path)
                     .map_err(|_| err)
                     .with_context(|| format!("Failed to read/deserialize the object from the file {} and temp file {}", path.display(), temp_file_path.display()))?;

                 Ok(FileLinked {
-                    val: Arc::new(RwLock::new(val)),
+                    val,
                     path: path.to_path_buf(),
                     temp_file_path,
                     file_thread: None,
-                    data_format,
                 })
             }
         }
     }

-    fn from_temp_file(
-        temp_file_path: &Path,
-        path: &Path,
-        data_format: &DataFormat,
-    ) -> Result<T, Error> {
+    fn from_temp_file(temp_file_path: &Path, path: &Path) -> Result<T, Error> {
         let file = File::open(temp_file_path)
             .with_context(|| format!("Unable to open file {}", temp_file_path.display()))?;

-        let val = match data_format {
-            DataFormat::Bincode => bincode::deserialize_from(file).with_context(|| {
-                format!(
-                    "Could not deserialize from temp file {}",
-                    temp_file_path.display()
-                )
-            })?,
-            DataFormat::Json => serde_json::from_reader(file).with_context(|| {
-                format!(
-                    "Could not deserialize from temp file {}",
-                    temp_file_path.display()
-                )
-            })?,
-        };
+        let val = bincode::deserialize_from(file).with_context(|| {
+            format!(
+                "Could not deserialize from temp file {}",
+                temp_file_path.display()
+            )
+        })?;

         info!("Successfully deserialized value from temp file");
@@ -507,12 +421,8 @@ mod tests {
             }
         }

-        pub async fn run<F, Fut>(&self, op: F) -> ()
-        where
-            F: FnOnce(PathBuf) -> Fut,
-            Fut: std::future::Future<Output = ()>,
-        {
-            op(self.path.clone()).await
+        pub fn run<F: FnOnce(&Path) -> Result<(), Error>>(&self, op: F) -> Result<(), Error> {
+            op(&self.path)
         }
     }
@@ -524,173 +434,92 @@ mod tests {
         }
     }

-    #[tokio::test]
-    async fn test_readonly() {
+    #[test]
+    fn test_readonly() -> Result<(), Error> {
         let path = PathBuf::from("test_readonly");
         let cleanup = CleanUp::new(&path);
-        cleanup
-            .run(|p| async move {
-                let val = vec!["one", "two", ""];
-                let linked_object = FileLinked::new(val.clone(), &p, DataFormat::Json)
-                    .await
-                    .expect("Unable to create file linked object");
-
-                let linked_object_arc = linked_object.readonly();
-                let linked_object_ref = linked_object_arc.read().await;
-                assert_eq!(*linked_object_ref, val);
-            })
-            .await;
+        cleanup.run(|p| {
+            let val = vec!["one", "two", ""];
+            let linked_object = FileLinked::new(val.clone(), &p)?;
+            assert_eq!(*linked_object.readonly(), val);
+
+            Ok(())
+        })
     }

-    #[tokio::test]
-    async fn test_new() {
+    #[test]
+    fn test_new() -> Result<(), Error> {
         let path = PathBuf::from("test_new");
         let cleanup = CleanUp::new(&path);
-        cleanup
-            .run(|p| async move {
-                let val = "test";
-                FileLinked::new(val, &p, DataFormat::Bincode)
-                    .await
-                    .expect("Unable to create file linked object");
-
-                let file = File::open(&p).expect("Unable to open file");
-                let result: String =
-                    bincode::deserialize_from(file).expect("Unable to deserialize from file");
-
-                assert_eq!(result, val);
-            })
-            .await;
+        cleanup.run(|p| {
+            let val = "test";
+            FileLinked::new(val, &p)?;
+
+            let file = File::open(&p)?;
+            let result: String =
+                bincode::deserialize_from(file).expect("Unable to deserialize from file");
+
+            assert_eq!(result, val);
+
+            Ok(())
+        })
     }

-    #[tokio::test]
-    async fn test_mutate() {
+    #[test]
+    fn test_mutate() -> Result<(), Error> {
         let path = PathBuf::from("test_mutate");
         let cleanup = CleanUp::new(&path);
-        cleanup
-            .run(|p| async move {
-                let list = vec![1, 2, 3, 4];
-                let mut file_linked_list = FileLinked::new(list, &p, DataFormat::Json)
-                    .await
-                    .expect("Unable to create file linked object");
-
-                let file_linked_list_arc = file_linked_list.readonly();
-                let file_linked_list_ref = file_linked_list_arc.read().await;
-                assert_eq!(*file_linked_list_ref, vec![1, 2, 3, 4]);
-                drop(file_linked_list_ref);
-
-                file_linked_list
-                    .mutate(|v1| v1.push(5))
-                    .await
-                    .expect("Error mutating file linked object");
-
-                let file_linked_list_arc = file_linked_list.readonly();
-                let file_linked_list_ref = file_linked_list_arc.read().await;
-                assert_eq!(*file_linked_list_ref, vec![1, 2, 3, 4, 5]);
-                drop(file_linked_list_ref);
-
-                file_linked_list
-                    .mutate(|v1| v1[1] = 1)
-                    .await
-                    .expect("Error mutating file linked object");
-
-                let file_linked_list_arc = file_linked_list.readonly();
-                let file_linked_list_ref = file_linked_list_arc.read().await;
-                assert_eq!(*file_linked_list_ref, vec![1, 1, 3, 4, 5]);
-
-                drop(file_linked_list);
-            })
-            .await;
+        cleanup.run(|p| {
+            let list = vec![1, 2, 3, 4];
+            let mut file_linked_list = FileLinked::new(list, &p)?;
+            assert_eq!(*file_linked_list.readonly(), vec![1, 2, 3, 4]);
+
+            file_linked_list.mutate(|v1| v1.push(5))?;
+            assert_eq!(*file_linked_list.readonly(), vec![1, 2, 3, 4, 5]);
+
+            file_linked_list.mutate(|v1| v1[1] = 1)?;
+            assert_eq!(*file_linked_list.readonly(), vec![1, 1, 3, 4, 5]);
+
+            drop(file_linked_list);
+
+            Ok(())
+        })
     }

-    #[tokio::test]
-    async fn test_async_mutate() {
-        let path = PathBuf::from("test_async_mutate");
-        let cleanup = CleanUp::new(&path);
-        cleanup
-            .run(|p| async move {
-                let list = vec![1, 2, 3, 4];
-                let mut file_linked_list = FileLinked::new(list, &p, DataFormat::Json)
-                    .await
-                    .expect("Unable to create file linked object");
-
-                let file_linked_list_arc = file_linked_list.readonly();
-                let file_linked_list_ref = file_linked_list_arc.read().await;
-                assert_eq!(*file_linked_list_ref, vec![1, 2, 3, 4]);
-                drop(file_linked_list_ref);
-
-                file_linked_list
-                    .mutate_async(|v1| async move {
-                        let mut v = v1.write().await;
-                        v.push(5);
-                        v[1] = 1;
-                        Ok::<(), Error>(())
-                    })
-                    .await
-                    .expect("Error mutating file linked object")
-                    .expect("Error mutating file linked object");
-
-                let file_linked_list_arc = file_linked_list.readonly();
-                let file_linked_list_ref = file_linked_list_arc.read().await;
-                assert_eq!(*file_linked_list_ref, vec![1, 1, 3, 4, 5]);
-
-                drop(file_linked_list);
-            })
-            .await;
-    }
-
-    #[tokio::test]
-    async fn test_replace() {
+    #[test]
+    fn test_replace() -> Result<(), Error> {
         let path = PathBuf::from("test_replace");
         let cleanup = CleanUp::new(&path);
-        cleanup
-            .run(|p| async move {
-                let val1 = String::from("val1");
-                let val2 = String::from("val2");
-
-                let mut file_linked_list = FileLinked::new(val1.clone(), &p, DataFormat::Bincode)
-                    .await
-                    .expect("Unable to create file linked object");
-
-                let file_linked_list_arc = file_linked_list.readonly();
-                let file_linked_list_ref = file_linked_list_arc.read().await;
-                assert_eq!(*file_linked_list_ref, val1);
-
-                file_linked_list
-                    .replace(val2.clone())
-                    .await
-                    .expect("Error replacing file linked object");
-
-                let file_linked_list_arc = file_linked_list.readonly();
-                let file_linked_list_ref = file_linked_list_arc.read().await;
-                assert_eq!(*file_linked_list_ref, val2);
-
-                drop(file_linked_list);
-            })
-            .await;
+        cleanup.run(|p| {
+            let val1 = String::from("val1");
+            let val2 = String::from("val2");
+
+            let mut file_linked_list = FileLinked::new(val1.clone(), &p)?;
+            assert_eq!(*file_linked_list.readonly(), val1);
+
+            file_linked_list.replace(val2.clone())?;
+            assert_eq!(*file_linked_list.readonly(), val2);
+
+            drop(file_linked_list);
+
+            Ok(())
+        })
     }

-    #[tokio::test]
-    async fn test_from_file() {
+    #[test]
+    fn test_from_file() -> Result<(), Error> {
         let path = PathBuf::from("test_from_file");
         let cleanup = CleanUp::new(&path);
-        cleanup
-            .run(|p| async move {
-                let value: Vec<f64> = vec![2.0, 3.0, 5.0];
-                let file = File::create(&p).expect("Unable to create file");
-                bincode::serialize_into(&file, &value).expect("Unable to serialize into file");
-                drop(file);
-
-                let linked_object: FileLinked<Vec<f64>> =
-                    FileLinked::from_file(&p, DataFormat::Bincode)
-                        .expect("Unable to create file linked object");
-
-                let linked_object_arc = linked_object.readonly();
-                let linked_object_ref = linked_object_arc.read().await;
-                assert_eq!(*linked_object_ref, value);
-
-                drop(linked_object);
-            })
-            .await;
+        cleanup.run(|p| {
+            let value: Vec<f64> = vec![2.0, 3.0, 5.0];
+            let file = File::create(&p)?;
+            bincode::serialize_into(&file, &value).expect("Unable to serialize into file");
+            drop(file);
+
+            let linked_object: FileLinked<Vec<f64>> = FileLinked::from_file(&p)?;
+            assert_eq!(*linked_object.readonly(), value);
+
+            drop(linked_object);
+
+            Ok(())
+        })
     }
 }
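
Both sides of this diff share one persistence strategy: serialize the value, hand the bytes to a background thread, and join that thread on Drop (and before starting the next write). A standalone sketch of that write-behind pattern under the same imports (thread, JoinHandle, File, Write, copy, remove_file); the helper name and the exact temp-file handling here are illustrative, not the crate's exact logic:

use std::{fs, io::Write, path::PathBuf, thread, thread::JoinHandle};

// Illustrative write-behind: persist bytes on a background thread by writing
// a temp file first, then copying it over the primary file so a crash
// mid-write never corrupts the primary copy.
fn spawn_write(bytes: Vec<u8>, path: PathBuf, temp_path: PathBuf) -> JoinHandle<()> {
    thread::spawn(move || {
        let mut file = fs::File::create(&temp_path).expect("unable to create temp file");
        file.write_all(&bytes).expect("unable to write temp file");
        fs::copy(&temp_path, &path).expect("unable to copy temp file over primary file");
        fs::remove_file(&temp_path).expect("unable to remove temp file");
    })
}

fn main() {
    let handle = spawn_write(b"data".to_vec(), "./state".into(), "./.temp_state".into());
    handle.join().expect("writer thread panicked"); // FileLinked joins in Drop
    fs::remove_file("./state").ok();
}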


@@ -15,22 +15,18 @@ categories = ["simulation"]
 [dependencies]
 serde = { version = "1.0", features = ["derive"] }
 serde_json = "1.0"
-uuid = { version = "1.7", features = ["serde", "v4"] }
-clap = { version = "4.5.2", features = ["derive"] }
-toml = "0.8.10"
+uuid = { version = "0.8", features = ["serde", "v4"] }
+clap = { version = "~2.27.0", features = ["yaml"] }
+toml = "0.5.8"
 regex = "1"
 file_linked = { version = "0.1.0", path = "../file_linked" }
 thiserror = "1.0"
 anyhow = "1.0"
-rand = "0.8.5"
-log = "0.4.21"
-env_logger = "0.11.3"
-futures = "0.3.30"
-tokio = { version = "1.37.0", features = ["full"] }
-num_cpus = "1.16.0"
-easy-parallel = "3.3.1"
-fann = "0.1.8"
-async-trait = "0.1.78"
-async-recursion = "1.1.0"
-lerp = "0.5.0"
-console-subscriber = "0.2.0"
+rand = "0.8.4"
+log = "0.4.14"
+env_logger = "0.9.0"
+futures = "0.3.17"
+smol = "1.2.5"
+smol-potat = "1.1.2"
+num_cpus = "1.13.0"
+easy-parallel = "3.1.0"


@@ -1,11 +0,0 @@
fn main() {
    // Replace this with the path to the directory containing `fann.lib`
    let lib_dir = "F://vandomej/Downloads/vcpkg/packages/fann_x64-windows/lib";
    println!("cargo:rustc-link-search=native={}", lib_dir);
    println!("cargo:rustc-link-lib=static=fann");

    // Use `dylib=fann` instead of `static=fann` if you're linking dynamically
    // If there are any additional directories where the compiler can find header files, you can specify them like this:
    // println!("cargo:include={}", path_to_include_directory);
}

gemla/cli.yml Normal file

@@ -0,0 +1,9 @@
name: GEMLA
version: "0.1"
author: Jacob VanDomelen <jacob.vandome15@gmail.com>
about: Uses a genetic algorithm to generate a machine learning algorithm.
args:
  - FILE:
      help: Sets the input/output file for the program.
      required: true
      index: 1

gemla/nodes.toml Normal file

@@ -0,0 +1,15 @@
[[nodes]]
fabric_addr = "10.0.0.1:9999"
bridge_bind = "10.0.0.1:8888"
mem = "100 GiB"
cpu = 8

# [[nodes]]
# fabric_addr = "10.0.0.2:9999"
# mem = "100 GiB"
# cpu = 16

# [[nodes]]
# fabric_addr = "10.0.0.3:9999"
# mem = "100 GiB"
# cpu = 16


@@ -1,72 +1,74 @@
+#[macro_use]
 extern crate clap;
 extern crate gemla;
 #[macro_use]
 extern crate log;

-mod fighter_nn;
 mod test_state;

-use anyhow::Result;
-use clap::Parser;
-use fighter_nn::FighterNN;
-use file_linked::constants::data_format::DataFormat;
+use anyhow::anyhow;
+use clap::App;
+use easy_parallel::Parallel;
 use gemla::{
+    constants::args::FILE,
     core::{Gemla, GemlaConfig},
-    error::log_error,
+    error::{log_error, Error},
 };
+use smol::{channel, channel::RecvError, future, Executor};
 use std::{path::PathBuf, time::Instant};
+use test_state::TestState;

-// const NUM_THREADS: usize = 2;
-
-#[derive(Parser)]
-#[command(version, about, long_about = None)]
-struct Args {
-    /// The file to read/write the dataset from/to.
-    #[arg(short, long)]
-    file: String,
-}
-
 /// Runs a simulation of a genetic algorithm against a dataset.
 ///
 /// Use the -h, --h, or --help flag to see usage syntax.
 /// TODO
-fn main() -> Result<()> {
+fn main() -> anyhow::Result<()> {
     env_logger::init();
-    // console_subscriber::init();

     info!("Starting");
     let now = Instant::now();

-    // Manually configure the Tokio runtime
-    let runtime: Result<()> = tokio::runtime::Builder::new_multi_thread()
-        .worker_threads(num_cpus::get())
-        // .worker_threads(NUM_THREADS)
-        .build()?
-        .block_on(async {
-            let args = Args::parse(); // Assuming Args::parse() doesn't need to be async
-            let mut gemla = log_error(
-                Gemla::<FighterNN>::new(
-                    &PathBuf::from(args.file),
-                    GemlaConfig { overwrite: false },
-                    DataFormat::Json,
-                )
-                .await,
-            )?;
-
-            // let gemla_arc = Arc::new(gemla);
-
-            // Setup your application logic here
-            // If `gemla::simulate` needs to run sequentially, simply call it in sequence without spawning new tasks
-
-            // Example placeholder loop to continuously run simulate
-            loop {
-                // Arbitrary loop count for demonstration
-                gemla.simulate(1).await?;
-            }
-        });
-
-    runtime?; // Handle errors from the block_on call
+    // Obtaining the number of threads to use
+    let num_threads = num_cpus::get().max(1);
+    let ex = Executor::new();
+    let (signal, shutdown) = channel::unbounded::<()>();
+
+    // Create an executor thread pool.
+    let (_, result): (Vec<Result<(), RecvError>>, Result<(), Error>) = Parallel::new()
+        .each(0..num_threads, |_| {
+            future::block_on(ex.run(shutdown.recv()))
+        })
+        .finish(|| {
+            smol::block_on(async {
+                drop(signal);
+
+                // Command line arguments are parsed with the clap crate, using
+                // clap's yaml method.
+                let yaml = load_yaml!("../../cli.yml");
+                let matches = App::from_yaml(yaml).get_matches();
+
+                // Checking that the first argument <FILE> is a valid file
+                if let Some(file_path) = matches.value_of(FILE) {
+                    let mut gemla = log_error(Gemla::<TestState>::new(
+                        &PathBuf::from(file_path),
+                        GemlaConfig {
+                            generations_per_node: 3,
+                            overwrite: true,
+                        },
+                    ))?;
+
+                    log_error(gemla.simulate(3).await)?;
+
+                    Ok(())
+                } else {
+                    Err(Error::Other(anyhow!("Invalid argument for FILE")))
+                }
+            })
+        });
+
+    result?;

     info!("Finished in {:?}", now.elapsed());

     Ok(())
 }
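
The new main replaces the Tokio runtime with smol: every worker thread drives a shared Executor until a shutdown channel closes, which is the pattern from smol's own documentation. Condensed to its skeleton (mirroring the right-hand side of this diff, minus the gemla-specific work):

use easy_parallel::Parallel;
use smol::{channel, future, Executor};

fn main() {
    let ex = Executor::new();
    let (signal, shutdown) = channel::unbounded::<()>();

    Parallel::new()
        // Each worker runs the executor until `signal` is dropped, which
        // closes the channel and lets `shutdown.recv()` resolve.
        .each(0..num_cpus::get().max(1), |_| {
            future::block_on(ex.run(shutdown.recv()))
        })
        // The main closure does its (async) work, then drops `signal`.
        .finish(|| {
            smol::block_on(async {
                drop(signal);
                // ... application logic would run here ...
            })
        });
}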


@@ -1,79 +0,0 @@
use std::sync::Arc;

use serde::ser::SerializeTuple;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use tokio::sync::Semaphore;

const SHARED_SEMAPHORE_CONCURRENCY_LIMIT: usize = 50;
const VISIBLE_SIMULATIONS_CONCURRENCY_LIMIT: usize = 1;

#[derive(Debug, Clone)]
pub struct FighterContext {
    pub shared_semaphore: Arc<Semaphore>,
    pub visible_simulations: Arc<Semaphore>,
}

impl Default for FighterContext {
    fn default() -> Self {
        FighterContext {
            shared_semaphore: Arc::new(Semaphore::new(SHARED_SEMAPHORE_CONCURRENCY_LIMIT)),
            visible_simulations: Arc::new(Semaphore::new(VISIBLE_SIMULATIONS_CONCURRENCY_LIMIT)),
        }
    }
}

// Custom serialization to just output the concurrency limit.
impl Serialize for FighterContext {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        // Assuming the semaphore's available permits represent the concurrency limit.
        // This part is tricky since Semaphore does not expose its initial permits.
        // You might need to store the concurrency limit as a separate field if this assumption doesn't hold.
        let concurrency_limit = SHARED_SEMAPHORE_CONCURRENCY_LIMIT;
        let visible_concurrency_limit = VISIBLE_SIMULATIONS_CONCURRENCY_LIMIT;

        // serializer.serialize_u64(concurrency_limit as u64)
        // Serialize the concurrency limits as a tuple
        let mut state = serializer.serialize_tuple(2)?;
        state.serialize_element(&concurrency_limit)?;
        state.serialize_element(&visible_concurrency_limit)?;
        state.end()
    }
}

// Custom deserialization to reconstruct the FighterContext from a concurrency limit.
impl<'de> Deserialize<'de> for FighterContext {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        // Deserialize the tuple
        let (_, _) = <(usize, usize)>::deserialize(deserializer)?;

        Ok(FighterContext {
            shared_semaphore: Arc::new(Semaphore::new(SHARED_SEMAPHORE_CONCURRENCY_LIMIT)),
            visible_simulations: Arc::new(Semaphore::new(VISIBLE_SIMULATIONS_CONCURRENCY_LIMIT)),
        })
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_serialization() {
        let context = FighterContext::default();
        let serialized = serde_json::to_string(&context).unwrap();
        let deserialized: FighterContext = serde_json::from_str(&serialized).unwrap();

        assert_eq!(
            context.shared_semaphore.available_permits(),
            deserialized.shared_semaphore.available_permits()
        );
        assert_eq!(
            context.visible_simulations.available_permits(),
            deserialized.visible_simulations.available_permits()
        );
    }
}
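
The semaphores in FighterContext only matter at simulation time: a task acquires a permit before launching a fight so at most 50 simulations (and one visible simulation) run at once. A hedged sketch of the acquire pattern a caller would use (run_gated_simulation is hypothetical, not a function from this file):

use std::sync::Arc;
use tokio::sync::Semaphore;

// Hypothetical caller: hold a permit for the duration of one simulation so
// the shared limit bounds how many run concurrently.
async fn run_gated_simulation(shared: Arc<Semaphore>) {
    let _permit = shared.acquire().await.expect("semaphore closed");
    // ... run the simulation while the permit is held ...
} // permit is released here when `_permit` drops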

File diff suppressed because it is too large

File diff suppressed because it is too large


@@ -1,11 +1,6 @@
-use async_trait::async_trait;
-use gemla::{
-    core::genetic_node::{GeneticNode, GeneticNodeContext},
-    error::Error,
-};
+use gemla::{core::genetic_node::GeneticNode, error::Error};
 use rand::prelude::*;
 use serde::{Deserialize, Serialize};
-use uuid::Uuid;

 const POPULATION_SIZE: u64 = 5;
 const POPULATION_REDUCTION_SIZE: u64 = 3;
@@ -13,30 +8,20 @@ const POPULATION_REDUCTION_SIZE: u64 = 3;
 #[derive(Serialize, Deserialize, Debug, Clone)]
 pub struct TestState {
     pub population: Vec<i64>,
-    pub max_generations: u64,
 }

-#[async_trait]
 impl GeneticNode for TestState {
-    type Context = ();
-
-    async fn initialize(_context: GeneticNodeContext<Self::Context>) -> Result<Box<Self>, Error> {
+    fn initialize() -> Result<Box<Self>, Error> {
         let mut population: Vec<i64> = vec![];

         for _ in 0..POPULATION_SIZE {
             population.push(thread_rng().gen_range(0..100))
         }

-        Ok(Box::new(TestState {
-            population,
-            max_generations: 10,
-        }))
+        Ok(Box::new(TestState { population }))
     }

-    async fn simulate(
-        &mut self,
-        context: GeneticNodeContext<Self::Context>,
-    ) -> Result<bool, Error> {
+    fn simulate(&mut self) -> Result<(), Error> {
         let mut rng = thread_rng();

         self.population = self
@@ -45,14 +30,10 @@ impl GeneticNode for TestState {
             .map(|p| p.saturating_add(rng.gen_range(-1..2)))
             .collect();

-        if context.generation >= self.max_generations {
-            Ok(false)
-        } else {
-            Ok(true)
-        }
+        Ok(())
     }

-    async fn mutate(&mut self, _context: GeneticNodeContext<Self::Context>) -> Result<(), Error> {
+    fn mutate(&mut self) -> Result<(), Error> {
         let mut rng = thread_rng();
         let mut v = self.population.clone();
@@ -90,12 +71,7 @@ impl GeneticNode for TestState {
         Ok(())
     }

-    async fn merge(
-        left: &TestState,
-        right: &TestState,
-        id: &Uuid,
-        gemla_context: Self::Context,
-    ) -> Result<Box<TestState>, Error> {
+    fn merge(left: &TestState, right: &TestState) -> Result<Box<TestState>, Error> {
         let mut v = left.population.clone();

         v.append(&mut right.population.clone());
@@ -104,18 +80,9 @@ impl GeneticNode for TestState {

         v = v[..(POPULATION_REDUCTION_SIZE as usize)].to_vec();

-        let mut result = TestState {
-            population: v,
-            max_generations: 10,
-        };
+        let mut result = TestState { population: v };

-        result
-            .mutate(GeneticNodeContext {
-                id: *id,
-                generation: 0,
-                gemla_context,
-            })
-            .await?;
+        result.mutate()?;

         Ok(Box::new(result))
     }
@@ -126,97 +93,57 @@ mod tests {
     use super::*;
     use gemla::core::genetic_node::GeneticNode;

-    #[tokio::test]
-    async fn test_initialize() {
-        let state = TestState::initialize(GeneticNodeContext {
-            id: Uuid::new_v4(),
-            generation: 0,
-            gemla_context: (),
-        })
-        .await
-        .unwrap();
+    #[test]
+    fn test_initialize() {
+        let state = TestState::initialize().unwrap();

         assert_eq!(state.population.len(), POPULATION_SIZE as usize);
     }

-    #[tokio::test]
-    async fn test_simulate() {
+    #[test]
+    fn test_simulate() {
         let mut state = TestState {
             population: vec![1, 1, 2, 3],
-            max_generations: 1,
         };

         let original_population = state.population.clone();

-        state
-            .simulate(GeneticNodeContext {
-                id: Uuid::new_v4(),
-                generation: 0,
-                gemla_context: (),
-            })
-            .await
-            .unwrap();
+        state.simulate().unwrap();

         assert!(original_population
             .iter()
             .zip(state.population.iter())
             .all(|(&a, &b)| b >= a - 1 && b <= a + 2));

-        state
-            .simulate(GeneticNodeContext {
-                id: Uuid::new_v4(),
-                generation: 0,
-                gemla_context: (),
-            })
-            .await
-            .unwrap();
-        state
-            .simulate(GeneticNodeContext {
-                id: Uuid::new_v4(),
-                generation: 0,
-                gemla_context: (),
-            })
-            .await
-            .unwrap();
+        state.simulate().unwrap();
+        state.simulate().unwrap();

         assert!(original_population
             .iter()
             .zip(state.population.iter())
             .all(|(&a, &b)| b >= a - 3 && b <= a + 6))
     }

-    #[tokio::test]
-    async fn test_mutate() {
+    #[test]
+    fn test_mutate() {
         let mut state = TestState {
             population: vec![4, 3, 3],
-            max_generations: 1,
         };

-        state
-            .mutate(GeneticNodeContext {
-                id: Uuid::new_v4(),
-                generation: 0,
-                gemla_context: (),
-            })
-            .await
-            .unwrap();
+        state.mutate().unwrap();

         assert_eq!(state.population.len(), POPULATION_SIZE as usize);
     }

-    #[tokio::test]
-    async fn test_merge() {
+    #[test]
+    fn test_merge() {
         let state1 = TestState {
             population: vec![1, 2, 4, 5],
-            max_generations: 1,
         };

         let state2 = TestState {
             population: vec![0, 1, 3, 7],
-            max_generations: 1,
         };

-        let merged_state = TestState::merge(&state1, &state2, &Uuid::new_v4(), ())
-            .await
-            .unwrap();
+        let merged_state = TestState::merge(&state1, &state2).unwrap();

         assert_eq!(merged_state.population.len(), POPULATION_SIZE as usize);
         assert!(merged_state.population.iter().any(|&x| x == 7));


@@ -0,0 +1,2 @@
/// Corresponds to the FILE command line argument used in accordance with the clap crate.
pub const FILE: &str = "FILE";


@@ -0,0 +1 @@
pub mod args;


@ -5,9 +5,7 @@
use crate::error::Error; use crate::error::Error;
use anyhow::Context; use anyhow::Context;
use async_trait::async_trait; use serde::{Deserialize, Serialize};
use log::info;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use std::fmt::Debug; use std::fmt::Debug;
use uuid::Uuid; use uuid::Uuid;
@ -26,65 +24,45 @@ pub enum GeneticState {
Finish, Finish,
} }
#[derive(Clone, Debug)]
pub struct GeneticNodeContext<S> {
pub generation: u64,
pub id: Uuid,
pub gemla_context: S,
}
/// A trait used to interact with the internal state of nodes within the [`Bracket`] /// A trait used to interact with the internal state of nodes within the [`Bracket`]
/// ///
/// [`Bracket`]: crate::bracket::Bracket /// [`Bracket`]: crate::bracket::Bracket
#[async_trait] pub trait GeneticNode {
pub trait GeneticNode: Send {
type Context;
/// Initializes a new instance of a [`GeneticState`]. /// Initializes a new instance of a [`GeneticState`].
/// ///
/// # Examples /// # Examples
/// TODO /// TODO
async fn initialize(context: GeneticNodeContext<Self::Context>) -> Result<Box<Self>, Error>; fn initialize() -> Result<Box<Self>, Error>;
async fn simulate(&mut self, context: GeneticNodeContext<Self::Context>) fn simulate(&mut self) -> Result<(), Error>;
-> Result<bool, Error>;
/// Mutates members in a population and/or crossbreeds them to produce new offspring. /// Mutates members in a population and/or crossbreeds them to produce new offspring.
/// ///
/// # Examples /// # Examples
/// TODO /// TODO
async fn mutate(&mut self, context: GeneticNodeContext<Self::Context>) -> Result<(), Error>; fn mutate(&mut self) -> Result<(), Error>;
async fn merge( fn merge(left: &Self, right: &Self) -> Result<Box<Self>, Error>;
left: &Self,
right: &Self,
id: &Uuid,
context: Self::Context,
) -> Result<Box<Self>, Error>;
} }
/// Used externally to wrap a node implementing the [`GeneticNode`] trait. Processes state transitions for the given node as /// Used externally to wrap a node implementing the [`GeneticNode`] trait. Processes state transitions for the given node as
/// well as signal recovery. Transition states are given by [`GeneticState`] /// well as signal recovery. Transition states are given by [`GeneticState`]
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] #[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct GeneticNodeWrapper<T> pub struct GeneticNodeWrapper<T> {
where
T: Clone,
{
node: Option<T>, node: Option<T>,
state: GeneticState, state: GeneticState,
generation: u64, generation: u64,
max_generations: u64,
id: Uuid, id: Uuid,
} }
impl<T> Default for GeneticNodeWrapper<T> impl<T> Default for GeneticNodeWrapper<T> {
where
T: Clone,
{
fn default() -> Self { fn default() -> Self {
GeneticNodeWrapper { GeneticNodeWrapper {
node: None, node: None,
state: GeneticState::Initialize, state: GeneticState::Initialize,
generation: 1, generation: 1,
max_generations: 1,
id: Uuid::new_v4(), id: Uuid::new_v4(),
} }
} }
@ -92,20 +70,21 @@ where
impl<T> GeneticNodeWrapper<T> impl<T> GeneticNodeWrapper<T>
where where
T: GeneticNode + Debug + Send + Clone, T: GeneticNode + Debug,
T::Context: Send + Sync + Clone + Debug + Serialize + DeserializeOwned + 'static + Default,
{ {
pub fn new() -> Self { pub fn new(max_generations: u64) -> Self {
GeneticNodeWrapper::<T> { GeneticNodeWrapper::<T> {
max_generations,
..Default::default() ..Default::default()
} }
} }
pub fn from(data: T, id: Uuid) -> Self { pub fn from(data: T, max_generations: u64, id: Uuid) -> Self {
GeneticNodeWrapper { GeneticNodeWrapper {
node: Some(data), node: Some(data),
state: GeneticState::Simulate, state: GeneticState::Simulate,
generation: 1, generation: 1,
max_generations,
id, id,
} }
} }
@ -114,51 +93,36 @@ where
self.node.as_ref() self.node.as_ref()
} }
pub fn take(&mut self) -> Option<T> {
self.node.take()
}
pub fn id(&self) -> Uuid { pub fn id(&self) -> Uuid {
self.id self.id
} }
pub fn generation(&self) -> u64 { pub fn max_generations(&self) -> u64 {
self.generation self.max_generations
} }
pub fn state(&self) -> GeneticState { pub fn state(&self) -> GeneticState {
self.state self.state
} }
pub async fn process_node(&mut self, gemla_context: T::Context) -> Result<GeneticState, Error> { pub fn process_node(&mut self) -> Result<GeneticState, Error> {
let context = GeneticNodeContext {
generation: self.generation,
id: self.id,
gemla_context,
};
match (self.state, &mut self.node) { match (self.state, &mut self.node) {
(GeneticState::Initialize, _) => { (GeneticState::Initialize, _) => {
self.node = Some(*T::initialize(context.clone()).await?); self.node = Some(*T::initialize()?);
self.state = GeneticState::Simulate; self.state = GeneticState::Simulate;
} }
(GeneticState::Simulate, Some(n)) => { (GeneticState::Simulate, Some(n)) => {
let next_generation = n n.simulate()
.simulate(context.clone())
.await
.with_context(|| format!("Error simulating node: {:?}", self))?; .with_context(|| format!("Error simulating node: {:?}", self))?;
info!("Simulation complete and continuing: {:?}", next_generation); self.state = if self.generation >= self.max_generations {
self.state = if next_generation {
GeneticState::Mutate
} else {
GeneticState::Finish GeneticState::Finish
} else {
GeneticState::Mutate
}; };
} }
(GeneticState::Mutate, Some(n)) => { (GeneticState::Mutate, Some(n)) => {
n.mutate(context.clone()) n.mutate()
.await
.with_context(|| format!("Error mutating node: {:?}", self))?; .with_context(|| format!("Error mutating node: {:?}", self))?;
self.generation += 1; self.generation += 1;
@ -177,64 +141,40 @@ mod tests {
use super::*; use super::*;
use crate::error::Error; use crate::error::Error;
use anyhow::anyhow; use anyhow::anyhow;
use async_trait::async_trait;
#[derive(Deserialize, Serialize, Clone, Debug, PartialEq)] #[derive(Deserialize, Serialize, Clone, Debug, PartialEq)]
struct TestState { struct TestState {
         pub score: f64,
-        pub max_generations: u64,
     }

-    #[async_trait]
     impl GeneticNode for TestState {
-        type Context = ();
-
-        async fn simulate(
-            &mut self,
-            context: GeneticNodeContext<Self::Context>,
-        ) -> Result<bool, Error> {
+        fn simulate(&mut self) -> Result<(), Error> {
             self.score += 1.0;
-            if context.generation >= self.max_generations {
-                Ok(false)
-            } else {
-                Ok(true)
-            }
-        }
-
-        async fn mutate(
-            &mut self,
-            _context: GeneticNodeContext<Self::Context>,
-        ) -> Result<(), Error> {
             Ok(())
         }

-        async fn initialize(
-            _context: GeneticNodeContext<Self::Context>,
-        ) -> Result<Box<TestState>, Error> {
-            Ok(Box::new(TestState {
-                score: 0.0,
-                max_generations: 2,
-            }))
+        fn mutate(&mut self) -> Result<(), Error> {
+            Ok(())
         }

-        async fn merge(
-            _l: &TestState,
-            _r: &TestState,
-            _id: &Uuid,
-            _: Self::Context,
-        ) -> Result<Box<TestState>, Error> {
+        fn initialize() -> Result<Box<TestState>, Error> {
+            Ok(Box::new(TestState { score: 0.0 }))
+        }
+
+        fn merge(_l: &TestState, _r: &TestState) -> Result<Box<TestState>, Error> {
             Err(Error::Other(anyhow!("Unable to merge")))
         }
     }

     #[test]
     fn test_new() -> Result<(), Error> {
-        let genetic_node = GeneticNodeWrapper::<TestState>::new();
+        let genetic_node = GeneticNodeWrapper::<TestState>::new(10);
         let other_genetic_node = GeneticNodeWrapper::<TestState> {
             node: None,
             state: GeneticState::Initialize,
             generation: 1,
+            max_generations: 10,
             id: genetic_node.id(),
         };
@@ -245,17 +185,15 @@ mod tests {
     #[test]
     fn test_from() -> Result<(), Error> {
-        let val = TestState {
-            score: 0.0,
-            max_generations: 10,
-        };
+        let val = TestState { score: 0.0 };
         let uuid = Uuid::new_v4();
-        let genetic_node = GeneticNodeWrapper::from(val.clone(), uuid);
+        let genetic_node = GeneticNodeWrapper::from(val.clone(), 10, uuid);
         let other_genetic_node = GeneticNodeWrapper::<TestState> {
             node: Some(val),
             state: GeneticState::Simulate,
             generation: 1,
+            max_generations: 10,
             id: genetic_node.id(),
         };
@@ -266,12 +204,9 @@ mod tests {
     #[test]
     fn test_as_ref() -> Result<(), Error> {
-        let val = TestState {
-            score: 3.0,
-            max_generations: 10,
-        };
+        let val = TestState { score: 3.0 };
         let uuid = Uuid::new_v4();
-        let genetic_node = GeneticNodeWrapper::from(val.clone(), uuid);
+        let genetic_node = GeneticNodeWrapper::from(val.clone(), 10, uuid);

         let ref_value = genetic_node.as_ref().unwrap();
@@ -282,12 +217,9 @@ mod tests {
     #[test]
     fn test_id() -> Result<(), Error> {
-        let val = TestState {
-            score: 3.0,
-            max_generations: 10,
-        };
+        let val = TestState { score: 3.0 };
         let uuid = Uuid::new_v4();
-        let genetic_node = GeneticNodeWrapper::from(val.clone(), uuid);
+        let genetic_node = GeneticNodeWrapper::from(val.clone(), 10, uuid);

         let id_value = genetic_node.id();
@@ -297,13 +229,23 @@ mod tests {
     }

     #[test]
-    fn test_state() -> Result<(), Error> {
-        let val = TestState {
-            score: 3.0,
-            max_generations: 10,
-        };
+    fn test_max_generations() -> Result<(), Error> {
+        let val = TestState { score: 3.0 };
         let uuid = Uuid::new_v4();
-        let genetic_node = GeneticNodeWrapper::from(val.clone(), uuid);
+        let genetic_node = GeneticNodeWrapper::from(val.clone(), 10, uuid);
+        let max_generations = genetic_node.max_generations();
+
+        assert_eq!(max_generations, 10);
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_state() -> Result<(), Error> {
+        let val = TestState { score: 3.0 };
+        let uuid = Uuid::new_v4();
+        let genetic_node = GeneticNodeWrapper::from(val.clone(), 10, uuid);
         let state = genetic_node.state();
@@ -312,16 +254,16 @@ mod tests {
         Ok(())
     }

-    #[tokio::test]
-    async fn test_process_node() -> Result<(), Error> {
-        let mut genetic_node = GeneticNodeWrapper::<TestState>::new();
+    #[test]
+    fn test_process_node() -> Result<(), Error> {
+        let mut genetic_node = GeneticNodeWrapper::<TestState>::new(2);
         assert_eq!(genetic_node.state(), GeneticState::Initialize);
-        assert_eq!(genetic_node.process_node(()).await?, GeneticState::Simulate);
-        assert_eq!(genetic_node.process_node(()).await?, GeneticState::Mutate);
-        assert_eq!(genetic_node.process_node(()).await?, GeneticState::Simulate);
-        assert_eq!(genetic_node.process_node(()).await?, GeneticState::Finish);
-        assert_eq!(genetic_node.process_node(()).await?, GeneticState::Finish);
+        assert_eq!(genetic_node.process_node()?, GeneticState::Simulate);
+        assert_eq!(genetic_node.process_node()?, GeneticState::Mutate);
+        assert_eq!(genetic_node.process_node()?, GeneticState::Simulate);
+        assert_eq!(genetic_node.process_node()?, GeneticState::Finish);
+        assert_eq!(genetic_node.process_node()?, GeneticState::Finish);
         Ok(())
     }
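Read as a whole, the hunks above revert `GeneticNode` from an async, context-driven trait to a synchronous one, and move the generation budget out of the node state (`TestState::max_generations`) into `GeneticNodeWrapper`. For orientation, here is a minimal sketch of the two trait shapes implied by the deleted and added lines, reconstructed from this diff rather than copied from the crate: `anyhow::Error` stands in for the crate's `Error`, and any `GeneticNodeContext` field beyond `generation` is an assumption.

use anyhow::Error;
use async_trait::async_trait;
use uuid::Uuid;

// Assumed shape; the diff only ever reads `context.generation`.
pub struct GeneticNodeContext<C> {
    pub generation: u64,
    pub gemla_context: C,
}

// Old shape (deleted lines): async, with a continue/finish protocol.
#[async_trait]
pub trait GeneticNodeOld: Sized {
    type Context: Send + Sync + 'static;
    /// Ok(true) keeps simulating; Ok(false) marks the node finished.
    async fn simulate(&mut self, context: GeneticNodeContext<Self::Context>) -> Result<bool, Error>;
    async fn mutate(&mut self, context: GeneticNodeContext<Self::Context>) -> Result<(), Error>;
    async fn initialize(context: GeneticNodeContext<Self::Context>) -> Result<Box<Self>, Error>;
    async fn merge(l: &Self, r: &Self, id: &Uuid, context: Self::Context) -> Result<Box<Self>, Error>;
}

// New shape (added lines): synchronous; the wrapper decides when to stop
// via its own `max_generations` counter.
pub trait GeneticNodeNew: Sized {
    fn simulate(&mut self) -> Result<(), Error>;
    fn mutate(&mut self) -> Result<(), Error>;
    fn initialize() -> Result<Box<Self>, Error>;
    fn merge(l: &Self, r: &Self) -> Result<Box<Self>, Error>;
}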

@@ -4,17 +4,15 @@
 pub mod genetic_node;

 use crate::{error::Error, tree::Tree};
-use async_recursion::async_recursion;
-use file_linked::{constants::data_format::DataFormat, FileLinked};
-use futures::future;
+use file_linked::FileLinked;
+use futures::{future, future::BoxFuture};
 use genetic_node::{GeneticNode, GeneticNodeWrapper, GeneticState};
 use log::{info, trace, warn};
 use serde::{de::DeserializeOwned, Deserialize, Serialize};
 use std::{
     collections::HashMap, fmt::Debug, fs::File, io::ErrorKind, marker::Send, mem, path::Path,
-    sync::Arc, time::Instant,
+    time::Instant,
 };
-use tokio::{sync::RwLock, task::JoinHandle};
 use uuid::Uuid;

 type SimulationTree<T> = Box<Tree<GeneticNodeWrapper<T>>>;
@@ -22,7 +20,7 @@ type SimulationTree<T> = Box<Tree<GeneticNodeWrapper<T>>>;
 /// Provides configuration options for managing a [`Gemla`] object as it executes.
 ///
 /// # Examples
-/// ```rust,ignore
+/// ```
 /// #[derive(Deserialize, Serialize, Clone, Debug, PartialEq)]
 /// struct TestState {
 ///     pub score: f64,
@@ -57,6 +55,7 @@ type SimulationTree<T> = Box<Tree<GeneticNodeWrapper<T>>>;
 /// ```
 #[derive(Serialize, Deserialize, Copy, Clone)]
 pub struct GemlaConfig {
+    pub generations_per_node: u64,
     pub overwrite: bool,
 }
@@ -66,125 +65,79 @@ pub struct GemlaConfig {
 /// individuals.
 ///
 /// [`GeneticNode`]: genetic_node::GeneticNode
-pub struct Gemla<T>
+pub struct Gemla<'a, T>
 where
-    T: GeneticNode + Serialize + DeserializeOwned + Debug + Send + Clone,
-    T::Context: Send + Sync + Clone + Debug + Serialize + DeserializeOwned + 'static + Default,
+    T: Serialize + Clone,
 {
-    pub data: FileLinked<(Option<SimulationTree<T>>, GemlaConfig, T::Context)>,
-    threads: HashMap<Uuid, JoinHandle<Result<GeneticNodeWrapper<T>, Error>>>,
+    pub data: FileLinked<(Option<SimulationTree<T>>, GemlaConfig)>,
+    threads: HashMap<Uuid, BoxFuture<'a, Result<GeneticNodeWrapper<T>, Error>>>,
 }

-impl<T: 'static> Gemla<T>
+impl<'a, T: 'a> Gemla<'a, T>
 where
-    T: GeneticNode + Serialize + DeserializeOwned + Debug + Send + Sync + Clone,
-    T::Context: Send + Sync + Clone + Debug + Serialize + DeserializeOwned + 'static + Default,
+    T: GeneticNode + Serialize + DeserializeOwned + Debug + Clone + Send,
 {
-    pub async fn new(
-        path: &Path,
-        config: GemlaConfig,
-        data_format: DataFormat,
-    ) -> Result<Self, Error> {
+    pub fn new(path: &Path, config: GemlaConfig) -> Result<Self, Error> {
         match File::open(path) {
             // If the file exists we either want to overwrite the file or read from the file
             // based on the configuration provided
             Ok(_) => Ok(Gemla {
                 data: if config.overwrite {
-                    FileLinked::new((None, config, T::Context::default()), path, data_format)
-                        .await?
+                    FileLinked::new((None, config), path)?
                 } else {
-                    FileLinked::from_file(path, data_format)?
+                    FileLinked::from_file(path)?
                 },
                 threads: HashMap::new(),
             }),
             // If the file doesn't exist we must create it
             Err(error) if error.kind() == ErrorKind::NotFound => Ok(Gemla {
-                data: FileLinked::new((None, config, T::Context::default()), path, data_format)
-                    .await?,
+                data: FileLinked::new((None, config), path)?,
                 threads: HashMap::new(),
             }),
             Err(error) => Err(Error::IO(error)),
         }
     }

-    pub fn tree_ref(&self) -> Arc<RwLock<(Option<SimulationTree<T>>, GemlaConfig, T::Context)>> {
-        self.data.readonly().clone()
+    pub fn tree_ref(&self) -> Option<&SimulationTree<T>> {
+        self.data.readonly().0.as_ref()
     }

     pub async fn simulate(&mut self, steps: u64) -> Result<(), Error> {
-        let tree_completed = {
-            // Only increase height if the tree is uninitialized or completed
-            let data_arc = self.data.readonly();
-            let data_ref = data_arc.read().await;
-            let tree_ref = data_ref.0.as_ref();
-            tree_ref.is_none() || tree_ref.map(|t| Gemla::is_completed(t)).unwrap_or(true)
-        };
-
-        if tree_completed {
         // Before we can process nodes we must create blank nodes in their place to keep track of which nodes have been processed
         // in the tree and which nodes have not.
-            self.data
-                .mutate(|(d, _, _)| {
-                    let mut tree: Option<SimulationTree<T>> =
-                        Gemla::increase_height(d.take(), steps);
-                    mem::swap(d, &mut tree);
-                })
-                .await?;
-        }
-
-        {
-            // Only increase height if the tree is uninitialized or completed
-            let data_arc = self.data.readonly();
-            let data_ref = data_arc.read().await;
-            let tree_ref = data_ref.0.as_ref();
+        self.data.mutate(|(d, c)| {
+            let mut tree: Option<SimulationTree<T>> = Gemla::increase_height(d.take(), c, steps);
+            mem::swap(d, &mut tree);
+        })?;

         info!(
             "Height of simulation tree increased to {}",
-            tree_ref
+            self.tree_ref()
                 .map(|t| format!("{}", t.height()))
                 .unwrap_or_else(|| "Tree is not defined".to_string())
         );
-        }

         loop {
-            let is_tree_processed;
-            {
-                let data_arc = self.data.readonly();
-                let data_ref = data_arc.read().await;
-                let tree_ref = data_ref.0.as_ref();
-                is_tree_processed = tree_ref.map(|t| Gemla::is_completed(t)).unwrap_or(false)
-            }
-
             // We need to keep simulating until the tree has been completely processed.
-            if is_tree_processed {
+            if self
+                .tree_ref()
+                .map(|t| Gemla::is_completed(t))
+                .unwrap_or(false)
+            {
                 self.join_threads().await?;

                 info!("Processed tree");
                 break;
             }

-            let (node, gemla_context) = {
-                let data_arc = self.data.readonly();
-                let data_ref = data_arc.read().await;
-                let (tree_ref, _, gemla_context) = &*data_ref; // (Option<Box<Tree<GeneticNodeWrapper<T>>>>, GemlaConfig, T::Context)
-                let node = tree_ref.as_ref().and_then(|t| self.get_unprocessed_node(t));
-                (node, gemla_context.clone())
-            };
-
-            if let Some(node) = node {
+            if let Some(node) = self
+                .tree_ref()
+                .and_then(|t| self.get_unprocessed_node(t))
+            {
                 trace!("Adding node to process list {}", node.id());

-                let gemla_context = gemla_context.clone();
-
-                self.threads.insert(
-                    node.id(),
-                    tokio::spawn(async move { Gemla::process_node(node, gemla_context).await }),
-                );
+                self.threads
+                    .insert(node.id(), Box::pin(Gemla::process_node(node)));
             } else {
                 trace!("No node found to process, joining threads");
@@ -200,26 +153,15 @@ where
         trace!("Joining threads for nodes {:?}", self.threads.keys());

         let results = future::join_all(self.threads.values_mut()).await;

         // Converting a list of results into a result wrapping the list
         let reduced_results: Result<Vec<GeneticNodeWrapper<T>>, Error> =
-            results.into_iter().flatten().collect();
+            results.into_iter().collect();
         self.threads.clear();

         // We need to retrieve the processed nodes from the resulting list and replace them in the original list
-        match reduced_results {
-            Ok(r) => {
-                self.data
-                    .mutate_async(|d| async move {
-                        // Scope to limit the duration of the read lock
-                        let (_, context) = {
-                            let data_read = d.read().await;
-                            (data_read.1, data_read.2.clone())
-                        }; // Read lock is dropped here
-
-                        let mut data_write = d.write().await;
-                        if let Some(t) = data_write.0.as_mut() {
+        reduced_results.and_then(|r| {
+            self.data.mutate(|(d, _)| {
+                if let Some(t) = d {
                     let failed_nodes = Gemla::replace_nodes(t, r);
                     // We receive a list of nodes that were unable to be found in the original tree
                     if !failed_nodes.is_empty() {
@@ -230,26 +172,19 @@ where
                     }

                     // Once the nodes are replaced we need to find nodes that can be merged from the completed children nodes
-                            Gemla::merge_completed_nodes(t, context.clone()).await
+                    Gemla::merge_completed_nodes(t)
                 } else {
                     warn!("Unable to replace nodes {:?} in empty tree", r);
                     Ok(())
                 }
-                    })
-                    .await??;
-            }
-            Err(e) => return Err(e),
-        }
+            })?
+        })?;
         }

         Ok(())
     }

-    #[async_recursion]
-    async fn merge_completed_nodes<'a>(
-        tree: &'a mut SimulationTree<T>,
-        gemla_context: T::Context,
-    ) -> Result<(), Error> {
+    fn merge_completed_nodes(tree: &mut SimulationTree<T>) -> Result<(), Error> {
         if tree.val.state() == GeneticState::Initialize {
             match (&mut tree.left, &mut tree.right) {
                 // If the current node has been initialized, and has children nodes that are completed, then we need
@@ -260,37 +195,43 @@ where
                 {
                     info!("Merging nodes {} and {}", l.val.id(), r.val.id());
                     if let (Some(left_node), Some(right_node)) = (l.val.as_ref(), r.val.as_ref()) {
-                        let merged_node = GeneticNode::merge(
-                            left_node,
-                            right_node,
-                            &tree.val.id(),
-                            gemla_context.clone(),
-                        )
-                        .await?;
-                        tree.val = GeneticNodeWrapper::from(*merged_node, tree.val.id());
+                        let merged_node = GeneticNode::merge(left_node, right_node)?;
+                        tree.val = GeneticNodeWrapper::from(
+                            *merged_node,
+                            tree.val.max_generations(),
+                            tree.val.id(),
+                        );
                     }
                 }
                 (Some(l), Some(r)) => {
-                    Gemla::merge_completed_nodes(l, gemla_context.clone()).await?;
-                    Gemla::merge_completed_nodes(r, gemla_context.clone()).await?;
+                    Gemla::merge_completed_nodes(l)?;
+                    Gemla::merge_completed_nodes(r)?;
                 }
                 // If there is only one child node that's completed then we want to copy it to the parent node
                 (Some(l), None) if l.val.state() == GeneticState::Finish => {
                     trace!("Copying node {}", l.val.id());
                     if let Some(left_node) = l.val.as_ref() {
-                        GeneticNodeWrapper::from(left_node.clone(), tree.val.id());
+                        GeneticNodeWrapper::from(
+                            left_node.clone(),
+                            tree.val.max_generations(),
+                            tree.val.id(),
+                        );
                     }
                 }
-                (Some(l), None) => Gemla::merge_completed_nodes(l, gemla_context.clone()).await?,
+                (Some(l), None) => Gemla::merge_completed_nodes(l)?,
                 (None, Some(r)) if r.val.state() == GeneticState::Finish => {
                     trace!("Copying node {}", r.val.id());
                     if let Some(right_node) = r.val.as_ref() {
-                        tree.val = GeneticNodeWrapper::from(right_node.clone(), tree.val.id());
+                        tree.val = GeneticNodeWrapper::from(
+                            right_node.clone(),
+                            tree.val.max_generations(),
+                            tree.val.id(),
+                        );
                     }
                 }
-                (None, Some(r)) => Gemla::merge_completed_nodes(r, gemla_context.clone()).await?,
+                (None, Some(r)) => Gemla::merge_completed_nodes(r)?,
                 (_, _) => (),
             }
         }
@@ -307,10 +248,7 @@ where
             // during join_threads.
             (Some(l), Some(r))
                 if l.val.state() == GeneticState::Finish
-                    && r.val.state() == GeneticState::Finish =>
-            {
-                Some(tree.val.clone())
-            }
+                    && r.val.state() == GeneticState::Finish => Some(tree.val.clone()),
             (Some(l), Some(r)) => self
                 .get_unprocessed_node(l)
                 .or_else(|| self.get_unprocessed_node(r)),
@@ -340,19 +278,25 @@ where
         }
     }

-    fn increase_height(tree: Option<SimulationTree<T>>, amount: u64) -> Option<SimulationTree<T>> {
+    fn increase_height(
+        tree: Option<SimulationTree<T>>,
+        config: &GemlaConfig,
+        amount: u64,
+    ) -> Option<SimulationTree<T>> {
         if amount == 0 {
             tree
         } else {
-            let left_branch_height =
+            let left_branch_right =
                 tree.as_ref().map(|t| t.height() as u64).unwrap_or(0) + amount - 1;
             Some(Box::new(Tree::new(
-                GeneticNodeWrapper::new(),
-                Gemla::increase_height(tree, amount - 1),
+                GeneticNodeWrapper::new(config.generations_per_node),
+                Gemla::increase_height(tree, config, amount - 1),
                 // The right branch height has to equal the left branch's total height
-                if left_branch_height > 0 {
-                    Some(Box::new(btree!(GeneticNodeWrapper::new())))
+                if left_branch_right > 0 {
+                    Some(Box::new(btree!(GeneticNodeWrapper::new(
+                        left_branch_right * config.generations_per_node
+                    ))))
                 } else {
                     None
                 },
@@ -365,16 +309,13 @@ where
         tree.val.state() == GeneticState::Finish
     }

-    async fn process_node(
-        mut node: GeneticNodeWrapper<T>,
-        gemla_context: T::Context,
-    ) -> Result<GeneticNodeWrapper<T>, Error> {
+    async fn process_node(mut node: GeneticNodeWrapper<T>) -> Result<GeneticNodeWrapper<T>, Error> {
         let node_state_time = Instant::now();
         let node_state = node.state();

-        node.process_node(gemla_context.clone()).await?;
+        node.process_node()?;

-        info!(
+        trace!(
             "{:?} completed in {:?} for {}",
             node_state,
             node_state_time.elapsed(),
@@ -392,13 +333,9 @@ where
 #[cfg(test)]
 mod tests {
     use crate::core::*;
-    use async_trait::async_trait;
     use serde::{Deserialize, Serialize};
-    use std::fs;
     use std::path::PathBuf;
-    use tokio::runtime::Runtime;
-
-    use self::genetic_node::GeneticNodeContext;
+    use std::fs;

     struct CleanUp {
         path: PathBuf,
@@ -427,43 +364,23 @@ mod tests {
     #[derive(Deserialize, Serialize, Clone, Debug, PartialEq)]
     struct TestState {
         pub score: f64,
-        pub max_generations: u64,
     }

-    #[async_trait]
     impl genetic_node::GeneticNode for TestState {
-        type Context = ();
-
-        async fn simulate(
-            &mut self,
-            context: GeneticNodeContext<Self::Context>,
-        ) -> Result<bool, Error> {
+        fn simulate(&mut self) -> Result<(), Error> {
             self.score += 1.0;
-            Ok(context.generation < self.max_generations)
-        }
-
-        async fn mutate(
-            &mut self,
-            _context: GeneticNodeContext<Self::Context>,
-        ) -> Result<(), Error> {
             Ok(())
         }

-        async fn initialize(
-            _context: GeneticNodeContext<Self::Context>,
-        ) -> Result<Box<TestState>, Error> {
-            Ok(Box::new(TestState {
-                score: 0.0,
-                max_generations: 10,
-            }))
+        fn mutate(&mut self) -> Result<(), Error> {
+            Ok(())
         }

-        async fn merge(
-            left: &TestState,
-            right: &TestState,
-            _id: &Uuid,
-            _: Self::Context,
-        ) -> Result<Box<TestState>, Error> {
+        fn initialize() -> Result<Box<TestState>, Error> {
+            Ok(Box::new(TestState { score: 0.0 }))
+        }
+
+        fn merge(left: &TestState, right: &TestState) -> Result<Box<TestState>, Error> {
             Ok(Box::new(if left.score > right.score {
                 left.clone()
             } else {
@@ -472,93 +389,66 @@ mod tests {
         }
     }

-    #[tokio::test]
-    async fn test_new() -> Result<(), Error> {
+    #[test]
+    fn test_new() -> Result<(), Error> {
         let path = PathBuf::from("test_new_non_existing");
-        // Use `spawn_blocking` to run synchronous code that needs to call async code internally.
-        tokio::task::spawn_blocking(move || {
-            let rt = Runtime::new().unwrap(); // Create a new Tokio runtime for the async block.
-            CleanUp::new(&path).run(move |p| {
-                rt.block_on(async {
+        CleanUp::new(&path).run(|p| {
             assert!(!path.exists());

             // Testing initial creation
-            let mut config = GemlaConfig { overwrite: true };
-            let mut gemla = Gemla::<TestState>::new(&p, config, DataFormat::Json).await?;
-
-            // Now we can use `.await` within the spawned blocking task.
-            gemla.simulate(2).await?;
-            let data = gemla.data.readonly();
-            let data_lock = data.read().await;
-            assert_eq!(data_lock.0.as_ref().unwrap().height(), 2);
-            drop(data_lock);
+            let mut config = GemlaConfig {
+                generations_per_node: 1,
+                overwrite: true
+            };
+            let mut gemla = Gemla::<TestState>::new(&p, config)?;
+
+            smol::block_on(gemla.simulate(2))?;
+            assert_eq!(gemla.data.readonly().0.as_ref().unwrap().height(), 2);
             drop(gemla);
             assert!(path.exists());

             // Testing overwriting data
-            let mut gemla = Gemla::<TestState>::new(&p, config, DataFormat::Json).await?;
-            gemla.simulate(2).await?;
-            let data = gemla.data.readonly();
-            let data_lock = data.read().await;
-            assert_eq!(data_lock.0.as_ref().unwrap().height(), 2);
-            drop(data_lock);
+            let mut gemla = Gemla::<TestState>::new(&p, config)?;
+            smol::block_on(gemla.simulate(2))?;
+            assert_eq!(gemla.data.readonly().0.as_ref().unwrap().height(), 2);
             drop(gemla);
             assert!(path.exists());

             // Testing not-overwriting data
             config.overwrite = false;
-            let mut gemla = Gemla::<TestState>::new(&p, config, DataFormat::Json).await?;
-            gemla.simulate(2).await?;
-            let data = gemla.data.readonly();
-            let data_lock = data.read().await;
-            let tree = data_lock.0.as_ref().unwrap();
-            assert_eq!(tree.height(), 4);
-            drop(data_lock);
+            let mut gemla = Gemla::<TestState>::new(&p, config)?;
+            smol::block_on(gemla.simulate(2))?;
+            assert_eq!(gemla.tree_ref().unwrap().height(), 4);
             drop(gemla);
             assert!(path.exists());

             Ok(())
         })
-            })
-        })
-        .await
-        .unwrap()?; // Wait for the blocking task to complete, then handle the Result.
-        Ok(())
     }

-    #[tokio::test]
-    async fn test_simulate() -> Result<(), Error> {
+    #[test]
+    fn test_simulate() -> Result<(), Error> {
         let path = PathBuf::from("test_simulate");
-        // Use `spawn_blocking` to run the synchronous closure that internally awaits async code.
-        tokio::task::spawn_blocking(move || {
-            let rt = Runtime::new().unwrap(); // Create a new Tokio runtime for the async block.
-            CleanUp::new(&path).run(move |p| {
-                rt.block_on(async {
+        CleanUp::new(&path).run(|p| {
             // Testing initial creation
-            let config = GemlaConfig { overwrite: true };
-            let mut gemla = Gemla::<TestState>::new(&p, config, DataFormat::Json).await?;
-
-            // Now we can use `.await` within the spawned blocking task.
-            gemla.simulate(5).await?;
-            let data = gemla.data.readonly();
-            let data_lock = data.read().await;
-            let tree = data_lock.0.as_ref().unwrap();
+            let config = GemlaConfig {
+                generations_per_node: 10,
+                overwrite: true
+            };
+            let mut gemla = Gemla::<TestState>::new(&p, config)?;
+
+            smol::block_on(gemla.simulate(5))?;
+            let tree = gemla.tree_ref().unwrap();
             assert_eq!(tree.height(), 5);
             assert_eq!(tree.val.as_ref().unwrap().score, 50.0);
             Ok(())
         })
-            })
-        })
-        .await
-        .unwrap()?; // Wait for the blocking task to complete, then handle the Result.
-        Ok(())
     }
 }
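Taken together, the added lines of this file give a much simpler construction path. Below is a minimal usage sketch of that API, modeled on the tests above; the path name, the config values, and the `run` wrapper are illustrative, and `TestState` stands in for any `GeneticNode` implementation.

use std::path::Path;

fn run() -> Result<(), Error> {
    let config = GemlaConfig {
        generations_per_node: 10, // budget each node gets before finishing
        overwrite: true,
    };
    let mut gemla = Gemla::<TestState>::new(Path::new("sim_data"), config)?;
    // simulate() stays async in both versions, so the synchronous API drives
    // it with a block_on executor, as the tests do.
    smol::block_on(gemla.simulate(5))?;
    println!(
        "tree height: {}",
        gemla.tree_ref().map(|t| t.height()).unwrap_or(0)
    );
    Ok(())
}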

@@ -1,4 +1,5 @@
 #[macro_use]
 pub mod tree;
+pub mod constants;
 pub mod core;
 pub mod error;

@@ -36,7 +36,7 @@ use std::cmp::max;
 /// t.right = Some(Box::new(btree!(3)));
 /// assert_eq!(t.right.unwrap().val, 3);
 /// ```
-#[derive(Default, Serialize, Deserialize, PartialEq, Debug)]
+#[derive(Default, Serialize, Deserialize, Clone, PartialEq, Debug)]
 pub struct Tree<T> {
     pub val: T,
     pub left: Option<Box<Tree<T>>>,
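The only substantive change to `Tree` is deriving `Clone`. A small sketch of what that enables, in the style of this module's own doc examples (the motivation is presumably to let whole subtrees of cloneable values be duplicated):

let mut t = btree!(1);
t.right = Some(Box::new(btree!(3)));
// With Clone derived, the entire subtree is duplicated, not shared.
let copy = t.clone();
assert_eq!(t, copy);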

@@ -1,380 +0,0 @@
# Re-importing necessary libraries
import json
import matplotlib.pyplot as plt
from collections import defaultdict
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.colors as mcolors
import matplotlib.cm as cm
import matplotlib.ticker as ticker
# Simplified JSON data for demonstration
with open('gemla/round4.json', 'r') as file:
simplified_json_data = json.load(file)
# Function to traverse the tree to find a node id
def traverse_right_nodes(node):
if node is None:
return []
right_node = node.get("right")
left_node = node.get("left")
if right_node is None and left_node is None:
return []
elif right_node and left_node:
return [right_node] + traverse_right_nodes(left_node)
return []
# Getting most recent right graph
right_nodes = traverse_right_nodes(simplified_json_data[0])
# Heatmaps
# Data structure to store mutation rates, generations, and scores
mutation_rate_data = defaultdict(lambda: defaultdict(list))
# Populate the dictionary with scores indexed by mutation rate and generation
for node in right_nodes:
node_val = node["val"]["node"]
if node_val:
scores = node_val["scores"]
minor_mutation_rate = node_val["minor_mutation_rate"]
generation = node_val["generation"]
# Ensure each score is associated with the correct generation
for gen_index, score_list in enumerate(scores):
for score in score_list.values():
mutation_rate_data[minor_mutation_rate][gen_index].append(score)
# Prepare data for heatmap
max_generation = max(max(gens.keys()) for gens in mutation_rate_data.values())
heatmap_data = np.full((len(mutation_rate_data), max_generation + 1), np.nan)
# Populate the heatmap data with average scores
mutation_rates = sorted(mutation_rate_data.keys())
for i, mutation_rate in enumerate(mutation_rates):
for generation in range(max_generation + 1):
scores = mutation_rate_data[mutation_rate][generation]
if scores: # Check if there are scores for this generation
heatmap_data[i, generation] = np.mean(scores)
# Creating a DataFrame for the heatmap
df_heatmap = pd.DataFrame(
data=heatmap_data,
index=mutation_rates,
columns=range(max_generation + 1)
)
# Data structure to store major mutation rates, generations, and scores
major_mutation_rate_data = defaultdict(lambda: defaultdict(list))
# Populate the dictionary with scores indexed by major mutation rate and generation
# This is assuming the structure to retrieve major_mutation_rate is similar to minor_mutation_rate
for node in right_nodes:
node_val = node["val"]["node"]
if node_val:
scores = node_val["scores"]
major_mutation_rate = node_val["major_mutation_rate"]
generation = node_val["generation"]
for gen_index, score_list in enumerate(scores):
for score in score_list.values():
major_mutation_rate_data[major_mutation_rate][gen_index].append(score)
# Prepare the heatmap data for major_mutation_rate similar to minor_mutation_rate
major_heatmap_data = np.full((len(major_mutation_rate_data), max_generation + 1), np.nan)
major_mutation_rates = sorted(major_mutation_rate_data.keys())
for i, major_rate in enumerate(major_mutation_rates):
for generation in range(max_generation + 1):
scores = major_mutation_rate_data[major_rate][generation]
if scores: # Check if there are scores for this generation
major_heatmap_data[i, generation] = np.mean(scores)
# Creating a DataFrame for the major mutation rate heatmap
df_major_heatmap = pd.DataFrame(
data=major_heatmap_data,
index=major_mutation_rates,
columns=range(max_generation + 1)
)
# crossbreed_segments
# Data structure to store crossbreed segment counts, generations, and scores
crossbreed_segments_data = defaultdict(lambda: defaultdict(list))
# Populate the dictionary with scores indexed by crossbreed segments and generation
# This is assuming the structure to retrieve crossbreed_segments is similar to minor_mutation_rate
for node in right_nodes:
node_val = node["val"]["node"]
if node_val:
scores = node_val["scores"]
crossbreed_segments = node_val["crossbreed_segments"]
generation = node_val["generation"]
for gen_index, score_list in enumerate(scores):
for score in score_list.values():
crossbreed_segments_data[crossbreed_segments][gen_index].append(score)
# Prepare the heatmap data for crossbreed_segments similar to minor_mutation_rate
crossbreed_heatmap_data = np.full((len(crossbreed_segments_data), max_generation + 1), np.nan)
crossbreed_segments = sorted(crossbreed_segments_data.keys())
for i, crossbreed_segment in enumerate(crossbreed_segments):
for generation in range(max_generation + 1):
scores = crossbreed_segments_data[crossbreed_segment][generation]
if scores: # Check if there are scores for this generation
crossbreed_heatmap_data[i, generation] = np.mean(scores)
# Creating a DataFrame for the crossbreed segments heatmap
df_crossbreed_heatmap = pd.DataFrame(
data=crossbreed_heatmap_data,
index=crossbreed_segments,
columns=range(max_generation + 1)
)
# mutation_weight_range
# Data structure to store mutation weight ranges, generations, and scores
mutation_weight_range_data = defaultdict(lambda: defaultdict(list))
# Populate the dictionary with scores indexed by mutation weight range and generation
# This is assuming the structure to retrieve mutation_weight_range is similar to minor_mutation_rate
for node in right_nodes:
node_val = node["val"]["node"]
if node_val:
scores = node_val["scores"]
mutation_weight_range = node_val["mutation_weight_range"]
positive_extent = mutation_weight_range["end"]
negative_extent = -mutation_weight_range["start"]
mutation_weight_range = (positive_extent + negative_extent) / 2
generation = node_val["generation"]
for gen_index, score_list in enumerate(scores):
for score in score_list.values():
mutation_weight_range_data[mutation_weight_range][gen_index].append(score)
# Prepare the heatmap data for mutation_weight_range similar to minor_mutation_rate
mutation_weight_range_heatmap_data = np.full((len(mutation_weight_range_data), max_generation + 1), np.nan)
mutation_weight_ranges = sorted(mutation_weight_range_data.keys())
for i, mutation_weight_range in enumerate(mutation_weight_ranges):
for generation in range(max_generation + 1):
scores = mutation_weight_range_data[mutation_weight_range][generation]
if scores: # Check if there are scores for this generation
mutation_weight_range_heatmap_data[i, generation] = np.mean(scores)
# Creating a DataFrame for the mutation weight range heatmap
df_mutation_weight_range_heatmap = pd.DataFrame(
data=mutation_weight_range_heatmap_data,
index=mutation_weight_ranges,
columns=range(max_generation + 1)
)
# weight_initialization_range
# Data structure to store weight initialization ranges, generations, and scores
weight_initialization_range_data = defaultdict(lambda: defaultdict(list))
# Populate the dictionary with scores indexed by weight initialization range and generation
# This is assuming the structure to retrieve weight_initialization_range is similar to minor_mutation_rate
for node in right_nodes:
node_val = node["val"]["node"]
if node_val:
scores = node_val["scores"]
weight_initialization_range = node_val["weight_initialization_range"]
positive_extent = weight_initialization_range["end"]
negative_extent = -weight_initialization_range["start"]
weight_initialization_range = (positive_extent + negative_extent) / 2
generation = node_val["generation"]
for gen_index, score_list in enumerate(scores):
for score in score_list.values():
weight_initialization_range_data[weight_initialization_range][gen_index].append(score)
# Prepare the heatmap data for weight_initialization_range similar to minor_mutation_rate
weight_initialization_range_heatmap_data = np.full((len(weight_initialization_range_data), max_generation + 1), np.nan)
weight_initialization_ranges = sorted(weight_initialization_range_data.keys())
for i, weight_initialization_range in enumerate(weight_initialization_ranges):
for generation in range(max_generation + 1):
scores = weight_initialization_range_data[weight_initialization_range][generation]
if scores: # Check if there are scores for this generation
weight_initialization_range_heatmap_data[i, generation] = np.mean(scores)
# Creating a DataFrame for the weight initialization range heatmap
df_weight_initialization_range_heatmap = pd.DataFrame(
data=weight_initialization_range_heatmap_data,
index=weight_initialization_ranges,
columns=range(max_generation + 1)
)
# weight_initialization_range_skew
# Data structure to store weight initialization range skews, generations, and scores
weight_initialization_range_skew_data = defaultdict(lambda: defaultdict(list))
# Populate the dictionary with scores indexed by weight initialization range skew and generation
# This is assuming the structure to retrieve weight_initialization_range_skew is similar to minor_mutation_rate
for node in right_nodes:
node_val = node["val"]["node"]
if node_val:
scores = node_val["scores"]
weight_initialization_range = node_val["weight_initialization_range"]
positive_extent = weight_initialization_range["end"]
negative_extent = -weight_initialization_range["start"]
weight_initialization_range_skew = (positive_extent - negative_extent) / 2
generation = node_val["generation"]
for gen_index, score_list in enumerate(scores):
for score in score_list.values():
weight_initialization_range_skew_data[weight_initialization_range_skew][gen_index].append(score)
# Prepare the heatmap data for weight_initialization_range_skew similar to minor_mutation_rate
weight_initialization_range_skew_heatmap_data = np.full((len(weight_initialization_range_skew_data), max_generation + 1), np.nan)
weight_initialization_range_skews = sorted(weight_initialization_range_skew_data.keys())
for i, weight_initialization_range_skew in enumerate(weight_initialization_range_skews):
for generation in range(max_generation + 1):
scores = weight_initialization_range_skew_data[weight_initialization_range_skew][generation]
if scores: # Check if there are scores for this generation
weight_initialization_range_skew_heatmap_data[i, generation] = np.mean(scores)
# Creating a DataFrame for the weight initialization range skew heatmap
df_weight_initialization_range_skew_heatmap = pd.DataFrame(
data=weight_initialization_range_skew_heatmap_data,
index=weight_initialization_range_skews,
columns=range(max_generation + 1)
)
# Analyze number of neurons correlation to score
# We can get the number of neurons via node_val["nn_shapes"] which contains an array of maps
# Each map has a key for the individual id and a value which is an array of integers representing the number of neurons in each layer
# We can use the individual id to get the score from the scores array
# We then generate a density map of the number of neurons vs the score
neuron_number_score_data = defaultdict(lambda: defaultdict(list))
for node in right_nodes:
node_val = node["val"]["node"]
if node_val:
scores = node_val["scores"]
nn_shapes = node_val["nn_shapes"]
# Both scores and nn_shapes are indexed by generation; scores is one element shorter than nn_shapes
for gen_index, score in enumerate(scores):
for individual_id, nn_shape in nn_shapes[gen_index].items():
neuron_number = sum(nn_shape)
# check if score has a value for the individual id
if individual_id not in score:
continue
neuron_number_score_data[neuron_number][gen_index].append(score[individual_id])
# prepare the density map data
neuron_number_score_heatmap_data = np.full((len(neuron_number_score_data), max_generation + 1), np.nan)
neuron_numbers = sorted(neuron_number_score_data.keys())
for i, neuron_number in enumerate(neuron_numbers):
for generation in range(max_generation + 1):
scores = neuron_number_score_data[neuron_number][generation]
if scores: # Check if there are scores for this generation
neuron_number_score_heatmap_data[i, generation] = np.mean(scores)
# Creating a DataFrame for the neuron count heatmap
df_neuron_number_score_heatmap = pd.DataFrame(
data=neuron_number_score_heatmap_data,
index=neuron_numbers,
columns=range(max_generation + 1)
)
# Analyze number of layers correlation to score
nn_layers_score_data = defaultdict(lambda: defaultdict(list))
for node in right_nodes:
node_val = node["val"]["node"]
if node_val:
scores = node_val["scores"]
nn_shapes = node_val["nn_shapes"]
# Both scores and nn_shapes are indexed by generation; scores is one element shorter than nn_shapes
for gen_index, score in enumerate(scores):
for individual_id, nn_shape in nn_shapes[gen_index].items():
layer_number = len(nn_shape)
# check if score has a value for the individual id
if individual_id not in score:
continue
nn_layers_score_data[layer_number][gen_index].append(score[individual_id])
# prepare the density map data
nn_layers_score_heatmap_data = np.full((len(nn_layers_score_data), max_generation + 1), np.nan)
nn_layers = sorted(nn_layers_score_data.keys())
for i, nn_layer in enumerate(nn_layers):
for generation in range(max_generation + 1):
scores = nn_layers_score_data[nn_layer][generation]
if scores: # Check if there are scores for this generation
nn_layers_score_heatmap_data[i, generation] = np.mean(scores)
# Creating a DataFrame for the layer count heatmap
df_nn_layers_score_heatmap = pd.DataFrame(
data=nn_layers_score_heatmap_data,
index=nn_layers,
columns=range(max_generation + 1)
)
# print("Format: ", custom_formatter(0.123498761234, 0))
# Creating subplots
fig, axs = plt.subplots(2, 2, figsize=(20, 14)) # Creates a 2x2 grid of subplots
# Plotting the minor mutation rate heatmap
sns.heatmap(df_heatmap.T, cmap='viridis', fmt=".4g", cbar_kws={'label': 'Mean Score'}, ax=axs[0, 0])
# axs[0, 0].set_title('Minor Mutation Rate')
axs[0, 0].set_xlabel('Minor Mutation Rate')
axs[0, 0].set_ylabel('Generation')
axs[0, 0].invert_yaxis()
# Plotting the major mutation rate heatmap
sns.heatmap(df_major_heatmap.T, cmap='viridis', fmt=".4g", cbar_kws={'label': 'Mean Score'}, ax=axs[0, 1])
# axs[0, 1].set_title('Major Mutation Rate')
axs[0, 1].set_xlabel('Major Mutation Rate')
axs[0, 1].invert_yaxis()
# Plotting the crossbreed_segments heatmap
sns.heatmap(df_crossbreed_heatmap.T, cmap='viridis', fmt=".4g", cbar_kws={'label': 'Mean Score'}, ax=axs[1, 0])
# axs[1, 0].set_title('Crossbreed Segments')
axs[1, 0].set_xlabel('Crossbreed Segments')
axs[1, 0].set_ylabel('Generation')
axs[1, 0].invert_yaxis()
# Plotting the mutation_weight_range heatmap
sns.heatmap(df_mutation_weight_range_heatmap.T, cmap='viridis', fmt=".4g", cbar_kws={'label': 'Mean Score'}, ax=axs[1, 1])
# axs[1, 1].set_title('Mutation Weight Range')
axs[1, 1].set_xlabel('Mutation Weight Range')
axs[1, 1].invert_yaxis()
fig3, axs3 = plt.subplots(1, 2, figsize=(20, 14)) # Creates a 1x2 grid of subplots
# Plotting the weight_initialization_range heatmap
sns.heatmap(df_weight_initialization_range_heatmap.T, cmap='viridis', fmt=".4g", cbar_kws={'label': 'Mean Score'}, ax=axs3[0])
# axs[2, 0].set_title('Weight Initialization Range')
axs3[0].set_xlabel('Weight Initialization Range')
axs3[0].set_ylabel('Generation')
axs3[0].invert_yaxis()
# Plotting the weight_initialization_range_skew heatmap
sns.heatmap(df_weight_initialization_range_skew_heatmap.T, cmap='viridis', fmt=".4g", cbar_kws={'label': 'Mean Score'}, ax=axs3[1])
# axs[2, 1].set_title('Weight Initialization Range Skew')
axs3[1].set_xlabel('Weight Initialization Range Skew')
axs3[1].set_ylabel('Generation')
axs3[1].invert_yaxis()
# Creating a new window for the scatter plots
fig2, axs2 = plt.subplots(2, 1, figsize=(20, 14)) # Creates a 2x1 grid of subplots
# Plotting the neuron number vs score heatmap
sns.heatmap(df_neuron_number_score_heatmap.T, cmap='viridis', fmt=".4g", cbar_kws={'label': 'Mean Score'}, ax=axs2[1])
# axs[3, 1].set_title('Neuron Number vs. Score')
axs2[1].set_xlabel('Neuron Number')
axs2[1].set_ylabel('Generation')
axs2[1].invert_yaxis()
# Plotting the number of layers vs score heatmap
sns.heatmap(df_nn_layers_score_heatmap.T, cmap='viridis', fmt=".4g", cbar_kws={'label': 'Mean Score'}, ax=axs2[0])
# axs[3, 1].set_title('Number of Layers vs. Score')
axs2[0].set_xlabel('Number of Layers')
axs2[0].set_ylabel('Generation')
axs2[0].invert_yaxis()
# Display the plot
plt.tight_layout() # Adjusts the subplots to fit into the figure area.
plt.show()

sbcl_spike/main.lisp (new file)

@@ -0,0 +1,12 @@
;; Define a type that contains a population size and a population cutoff
(defclass simulation-node ()
  ((population-size :initarg :population-size :accessor population-size)
   (population-cutoff :initarg :population-cutoff :accessor population-cutoff)
   (population :initform () :accessor population)))

;; Define a method that initializes population-size number of children in a
;; population, each with its own random value. (LOOP is used here because
;; MAKE-LIST with :INITIAL-ELEMENT would evaluate (RANDOM 100) only once and
;; fill the list with a single shared value.)
(defmethod initialize-instance :after ((node simulation-node) &key)
  (setf (population node)
        (loop repeat (population-size node) collect (random 100))))

(let ((node (make-instance 'simulation-node :population-size 100 :population-cutoff 10)))
  (print (population-size node))
  (population node))

@@ -1,118 +0,0 @@
import matplotlib.pyplot as plt
import networkx as nx
import subprocess
import tkinter as tk
from tkinter import filedialog
def select_file():
root = tk.Tk()
root.withdraw() # Hide the main window
file_path = filedialog.askopenfilename(
initialdir="/", # Set the initial directory to search for files
title="Select file",
filetypes=(("Net files", "*.net"), ("All files", "*.*"))
)
return file_path
def get_fann_data(network_file):
# Adjust the path to the Rust executable as needed
result = subprocess.run(['./extract_fann_data/target/debug/extract_fann_data.exe', network_file], capture_output=True, text=True)
if result.returncode != 0:
print("Error:", result.stderr)
return None, None
layer_sizes = []
connections = []
parsing_connections = False
for line in result.stdout.splitlines():
if line.startswith("Layers:"):
continue
elif line.startswith("Connections:"):
parsing_connections = True
continue
if parsing_connections:
from_neuron, to_neuron, weight = map(float, line.split())
connections.append((int(from_neuron), int(to_neuron), weight))
else:
layer_size, bias_count = map(int, line.split())
layer_sizes.append((layer_size, bias_count))
return layer_sizes, connections
def visualize_fann_network(network_file):
# Get network data
layer_sizes, connections = get_fann_data(network_file)
if layer_sizes is None or connections is None:
return # Error handling in get_fann_data should provide error output
# Create a directed graph
G = nx.DiGraph()
# Positions dictionary to hold the position of each neuron
pos = {}
node_count = 0
x_spacing = 1.0
y_spacing = 1.0
# Calculate the maximum layer size for proper spacing
max_layer_size = max(size for size, bias in layer_sizes)
# Build nodes and position them layer by layer from left to right
for layer_index, (layer_size, bias_count) in enumerate(layer_sizes):
y_positions = list(range(-layer_size-bias_count+1, 1, 1)) # Center-align vertically
y_positions = [y * (max_layer_size / (layer_size + bias_count)) * y_spacing for y in y_positions] # Adjust spacing
for neuron_index in range(layer_size + bias_count): # Include bias neurons
node_label = f"L{layer_index}N{neuron_index}"
G.add_node(node_count, label=node_label)
pos[node_count] = (layer_index * x_spacing, y_positions[neuron_index % len(y_positions)])
node_count += 1
# Add connections to the graph
for from_neuron, to_neuron, weight in connections:
G.add_edge(from_neuron, to_neuron, weight=weight)
max_weight = max(abs(weight) for _, _, weight in connections)
print(f"Max weight: {max_weight}")
# Draw nodes
nx.draw_networkx_nodes(G, pos, node_color='skyblue', node_size=200)
nx.draw_networkx_labels(G, pos, font_size=7)
# Custom function for edge properties
def adjust_properties(weight):
# if weight > 0:
# print("Weight:", weight)
color = 'green' if weight > 0 else 'red'
alpha = min((abs(weight) / max_weight) ** 3, 1)
# print(f"Color: {color}, Alpha: {alpha}")
return color, alpha
# Draw edges with custom properties
for u, v, d in G.edges(data=True):
color, alpha = adjust_properties(d['weight'])
nx.draw_networkx_edges(G, pos, edgelist=[(u, v)], edge_color=color, alpha=alpha, width=1.5, arrows=False)
# Show plot
plt.title('FANN Network Visualization')
plt.axis('off') # Turn off the axis
plt.show()
# Path to the FANN network file
fann_path = 'F:\\\\vandomej\\Projects\\dootcamp-AI-Simulation\\Simulations\\fighter_nn_4f2be613-ab26-4384-9a65-450e043984ea\\6\\4f2be613-ab26-4384-9a65-450e043984ea_fighter_nn_0.net'
# fann_path = "F:\\\\vandomej\\Projects\\dootcamp-AI-Simulation\\Simulations\\fighter_nn_fc294503-7b2a-40f8-be59-ccc486eb3f79\\0\\fc294503-7b2a-40f8-be59-ccc486eb3f79_fighter_nn_0.net"
# fann_path = 'F:\\\\vandomej\\Projects\\dootcamp-AI-Simulation\\Simulations\\fighter_nn_99c30a7f-40ab-4faf-b16a-b44703fdb6cd\\0\\99c30a7f-40ab-4faf-b16a-b44703fdb6cd_fighter_nn_0.net'
# Has a 4 layer network
# # Generation 1
# fann_path = "F:\\\\vandomej\\Projects\\dootcamp-AI-Simulation\\Simulations\\fighter_nn_16dfa1b4-03c7-45a6-84b4-22fe3c8e2d98\\1\\16dfa1b4-03c7-45a6-84b4-22fe3c8e2d98_fighter_nn_0.net"
# # Generation 5
# fann_path = "F:\\\\vandomej\\Projects\\dootcamp-AI-Simulation\\Simulations\\fighter_nn_16dfa1b4-03c7-45a6-84b4-22fe3c8e2d98\\5\\16dfa1b4-03c7-45a6-84b4-22fe3c8e2d98_fighter_nn_0.net"
# # Generation 10
# fann_path = "F:\\\\vandomej\\Projects\\dootcamp-AI-Simulation\\Simulations\\fighter_nn_16dfa1b4-03c7-45a6-84b4-22fe3c8e2d98\\10\\16dfa1b4-03c7-45a6-84b4-22fe3c8e2d98_fighter_nn_0.net"
# # Generation 20
# fann_path = "F:\\\\vandomej\\Projects\\dootcamp-AI-Simulation\\Simulations\\fighter_nn_16dfa1b4-03c7-45a6-84b4-22fe3c8e2d98\\20\\16dfa1b4-03c7-45a6-84b4-22fe3c8e2d98_fighter_nn_0.net"
# # Generation 32
# fann_path = "F:\\\\vandomej\\Projects\\dootcamp-AI-Simulation\\Simulations\\fighter_nn_16dfa1b4-03c7-45a6-84b4-22fe3c8e2d98\\32\\16dfa1b4-03c7-45a6-84b4-22fe3c8e2d98_fighter_nn_0.net"
fann_path = select_file()
visualize_fann_network(fann_path)

@@ -1,104 +0,0 @@
# Re-importing necessary libraries
import json
import matplotlib.pyplot as plt
import networkx as nx
import random
def hierarchy_pos(G, root=None, width=1., vert_gap=0.2, vert_loc=0, xcenter=0.5):
if not nx.is_tree(G):
raise TypeError('cannot use hierarchy_pos on a graph that is not a tree')
if root is None:
if isinstance(G, nx.DiGraph):
root = next(iter(nx.topological_sort(G)))
else:
root = random.choice(list(G.nodes))
def _hierarchy_pos(G, root, width=2., vert_gap=0.2, vert_loc=0, xcenter=0.5, pos=None, parent=None):
if pos is None:
pos = {root: (xcenter, vert_loc)}
else:
pos[root] = (xcenter, vert_loc)
children = list(G.successors(root)) # Use successors to get children for DiGraph
if not isinstance(G, nx.DiGraph):
if parent is not None:
children.remove(parent)
if len(children) != 0:
dx = width / len(children)
nextx = xcenter - width / 2 - dx / 2
for child in children:
nextx += dx
pos = _hierarchy_pos(G, child, width=dx*2.0, vert_gap=vert_gap,
vert_loc=vert_loc - vert_gap, xcenter=nextx,
pos=pos, parent=root)
return pos
return _hierarchy_pos(G, root, width, vert_gap, vert_loc, xcenter)
# Simplified JSON data for demonstration
with open('gemla/round4.json', 'r') as file:
simplified_json_data = json.load(file)
# Function to traverse the tree and create a graph
def traverse(node, graph, parent=None):
if node is None:
return
node_id = node["val"]["id"]
if "node" in node["val"] and node["val"]["node"]:
scores = node["val"]["node"]["scores"]
generations = node["val"]["node"]["generation"]
population_size = node["val"]["node"]["population_size"]
# Track the best score and corresponding individual; with the comparison below commented out, this records the most recent generation's best rather than the overall best
overall_max_score = float('-inf')
overall_max_score_individual = None
overall_max_score_gen = None
for gen, gen_scores in enumerate(scores):
if gen_scores: # Ensure the dictionary is not empty
# Find the max score and the individual for this generation
max_score_for_gen = max(gen_scores.values())
individual_with_max_score_for_gen = max(gen_scores, key=gen_scores.get)
# if max_score_for_gen > overall_max_score:
overall_max_score = max_score_for_gen
overall_max_score_individual = individual_with_max_score_for_gen
overall_max_score_gen = gen
# print debug statement
# print(f"Node {node_id}: Max score: {overall_max_score:.6f} (Individual {overall_max_score_individual} in Gen {overall_max_score_gen})")
# print(f"Left: {node.get('left')}, Right: {node.get('right')}")
label = f"{node_id}\nGenerations: {generations}, Population: {population_size}\nMax score: {overall_max_score:.6f} (Individual {overall_max_score_individual} in Gen {overall_max_score_gen + 1 if overall_max_score_gen is not None else 'N/A'})"
else:
label = node_id
graph.add_node(node_id, label=label)
if parent:
graph.add_edge(parent, node_id)
traverse(node.get("left"), graph, parent=node_id)
traverse(node.get("right"), graph, parent=node_id)
# Create a directed graph
G = nx.DiGraph()
# Populate the graph
traverse(simplified_json_data[0], G)
# Find the root node (a node with no incoming edges)
root_candidates = [node for node, indeg in G.in_degree() if indeg == 0]
if root_candidates:
root_node = root_candidates[0] # Assuming there's only one root candidate
else:
root_node = None # This should ideally never happen in a properly structured tree
# Use the determined root node for hierarchy_pos
if root_node is not None:
pos = hierarchy_pos(G, root=root_node)
labels = nx.get_node_attributes(G, 'label')
nx.draw(G, pos, labels=labels, with_labels=True, arrows=True)
plt.show()
else:
print("No root node found. Cannot draw the tree.")