Merge pull request 'dootcamp' (#1) from dootcamp into master

Reviewed-on: #1
This commit is contained in:
tepichord 2025-09-05 09:37:40 -07:00
commit 6725ab3feb
32 changed files with 5278 additions and 477 deletions

6
.gitignore vendored
View file

@@ -13,4 +13,8 @@ settings.json
.DS_Store
.vscode/alive
# Added by cargo
/target

50
.vscode/launch.json vendored
View file

@@ -10,7 +10,55 @@
"name": "Debug",
"program": "${workspaceFolder}/gemla/target/debug/gemla.exe",
"args": ["./gemla/temp/"],
"cwd": "${workspaceFolder}"
"cwd": "${workspaceFolder}/gemla"
},
{
"type": "lldb",
"request": "launch",
"name": "Debug Rust Tests",
"cargo": {
"args": [
"test",
"--manifest-path", "${workspaceFolder}/gemla/Cargo.toml",
"--no-run", // Compiles the tests without running them
"--package=gemla", // Specify your package name if necessary
"--bin=bin"
],
"filter": { }
},
"args": [],
},
{
"type": "lldb",
"request": "launch",
"name": "Debug gemla Lib Tests",
"cargo": {
"args": [
"test",
"--manifest-path", "${workspaceFolder}/gemla/Cargo.toml",
"--no-run", // Compiles the tests without running them
"--package=gemla", // Specify your package name if necessary
"--lib"
],
"filter": { }
},
"args": [],
},
{
"type": "lldb",
"request": "launch",
"name": "Debug Rust FileLinked Tests",
"cargo": {
"args": [
"test",
"--manifest-path", "${workspaceFolder}/file_linked/Cargo.toml",
"--no-run", // Compiles the tests without running them
"--package=file_linked", // Specify your package name if necessary
"--lib"
],
"filter": { }
},
"args": [],
}
]
}

171
analyze_data.py Normal file
View file

@@ -0,0 +1,171 @@
# Import the libraries needed for analysis
import json
import matplotlib.pyplot as plt
from collections import defaultdict
import numpy as np
# Load the simulation tree JSON data
with open('gemla/round4.json', 'r') as file:
simplified_json_data = json.load(file)
target_node_id = '523f8250-3101-4586-90a1-127ffa6d73d9'
# Traverse the tree, collecting every node along the left spine (the main line)
def traverse_left_nodes(node):
if node is None:
return []
left_node = node.get("left")
if left_node is None:
return [node]
return [node] + traverse_left_nodes(left_node)
# Traverse the tree, collecting the right child of each node while descending the left spine
def traverse_right_nodes(node):
if node is None:
return []
right_node = node.get("right")
left_node = node.get("left")
if right_node is None and left_node is None:
return []
elif right_node and left_node:
return [right_node] + traverse_right_nodes(left_node)
return []
# Getting the left graph
left_nodes = traverse_left_nodes(simplified_json_data[0])
left_nodes.reverse()
# print(node)
# Print properties available on the first node
node = left_nodes[0]
# print(node["val"].keys())
scores = []
for node in left_nodes:
# print(node)
# print(f'Node ID: {node["val"]["id"]}')
# print(f'Node scores length: {len(node["val"]["node"]["scores"])}')
if node["val"]["node"]:
node_scores = node["val"]["node"]["scores"]
if node_scores:
for score in node_scores:
scores.append(score)
# print(scores)
scores_values = [list(score_set.values()) for score_set in scores]
# Set up the figure for plotting on the same graph
fig, ax = plt.subplots(figsize=(10, 6))
# Generate a boxplot for each set of scores on the same graph
boxplots = ax.boxplot(scores_values, vert=False, patch_artist=True, labels=[f'Set {i+1}' for i in range(len(scores_values))])
# Use a symmetric log scale so widely spread scores remain readable
ax.set_xscale('symlog', linthresh=1.0)
# Labeling
ax.set_xlabel('Scores - Main Line')
ax.set_ylabel('Score Sets')
ax.yaxis.grid(True) # Add horizontal grid lines for clarity
# Set y-axis labels to be visible
ax.set_yticklabels([f'Set {i+1}' for i in range(len(scores_values))])
# Getting most recent right graph
right_nodes = traverse_right_nodes(simplified_json_data[0])
if len(right_nodes) != 0:
target_node_id = None
target_node = None
if target_node_id:
for node in right_nodes:
if node["val"]["id"] == target_node_id:
target_node = node
break
else:
target_node = right_nodes[0]
scores = target_node["val"]["node"]["scores"]
scores_values = [list(score_set.values()) for score_set in scores]
# Set up the figure for plotting on the same graph
fig, ax = plt.subplots(figsize=(10, 6))
# Generate a boxplot for each set of scores on the same graph
boxplots = ax.boxplot(scores_values, vert=False, patch_artist=True, labels=[f'Set {i+1}' for i in range(len(scores_values))])
ax.set_xscale('symlog', linthresh=1.0)
# Labeling
ax.set_xlabel(f"Scores: {target_node['val']['id']}")
ax.set_ylabel('Score Sets')
ax.yaxis.grid(True) # Add horizontal grid lines for clarity
# Set y-axis labels to be visible
ax.set_yticklabels([f'Set {i+1}' for i in range(len(scores_values))])
# Find the highest scoring sets combining all scores and generations
scores = []
for node in left_nodes:
if node["val"]["node"]:
node_scores = node["val"]["node"]["scores"]
translated_node_scores = []
if node_scores:
for i in range(len(node_scores)):
for (individual, score) in node_scores[i].items():
translated_node_scores.append((node["val"]["id"], i, score))
scores.append(translated_node_scores)
# Add scores from the right nodes
if len(right_nodes) != 0:
for node in right_nodes:
if node["val"]["node"]:
node_scores = node["val"]["node"]["scores"]
translated_node_scores = []
if node_scores:
for i in range(len(node_scores)):
for (individual, score) in node_scores[i].items():
translated_node_scores.append((node["val"]["id"], i, score))
scores.append(translated_node_scores)
# Organize scores by individual and then by generation
individual_generation_scores = defaultdict(lambda: defaultdict(list))
for sublist in scores:
for id, generation, score in sublist:
individual_generation_scores[id][generation].append(score)
# Calculate Q3 for each individual's generation
individual_generation_q3 = {}
for id, generations in individual_generation_scores.items():
for gen, scores in generations.items():
individual_generation_q3[(id, gen)] = np.percentile(scores, 75)
# Sort by Q3 value, highest first, and select the top 40 individual generations
top_individual_generations = sorted(individual_generation_q3, key=individual_generation_q3.get, reverse=True)[:40]
# Prepare scores for the top individual generations for plotting
top_scores = [individual_generation_scores[id][gen] for id, gen in top_individual_generations]
# Adjust labels for clarity, indicating both the individual ID and generation
labels = [f'{id[:8]}... Gen {gen}' for id, gen in top_individual_generations]
# Generate box and whisker plots for the top individual generations
fig, ax = plt.subplots(figsize=(12, 10))
ax.boxplot(top_scores, vert=False, patch_artist=True, labels=labels)
ax.set_xscale('symlog', linthresh=1.0)
ax.set_xlabel('Scores')
ax.set_ylabel('Individual Generation')
ax.set_title('Top 40 Individual Generations by Q3 Value')
ax.yaxis.grid(True) # Add horizontal grid lines for clarity
# Display the plot
plt.show()

View file

@@ -1 +0,0 @@
out/

View file

@@ -1,10 +0,0 @@
(use Random)
(Project.config "title" "gemla")
(deftype SimulationNode [population-size Int, population-cutoff Int])
;; (let [test (SimulationNode.init 10 3)]
;; (do
;; (SimulationNode.set-population-size test 20)
;; (SimulationNode.population-size &test)
;; ))

View file

@@ -0,0 +1,9 @@
[package]
name = "extract_fann_data"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
fann = "0.1.8"

View file

@@ -0,0 +1,11 @@
fn main() {
// Replace this with the path to the directory containing `fann.lib`
let lib_dir = "F://vandomej/Downloads/vcpkg/packages/fann_x64-windows/lib";
println!("cargo:rustc-link-search=native={}", lib_dir);
println!("cargo:rustc-link-lib=static=fann");
// Use `dylib=fann` instead of `static=fann` if you're linking dynamically
// If there are any additional directories where the compiler can find header files, you can specify them like this:
// println!("cargo:include={}", path_to_include_directory);
}
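
The path above is machine-specific. A hedged alternative sketch that reads the directory from an environment variable instead; `FANN_LIB_DIR` is an assumed name, not something cargo or fann defines:

fn main() {
    // Fall back to the original hardcoded location when FANN_LIB_DIR is unset.
    let lib_dir = std::env::var("FANN_LIB_DIR")
        .unwrap_or_else(|_| "F://vandomej/Downloads/vcpkg/packages/fann_x64-windows/lib".to_string());
    // Re-run this build script when the variable changes.
    println!("cargo:rerun-if-env-changed=FANN_LIB_DIR");
    println!("cargo:rustc-link-search=native={}", lib_dir);
    println!("cargo:rustc-link-lib=static=fann");
}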

View file

@@ -0,0 +1,38 @@
extern crate fann;
use fann::Fann;
use std::os::raw::c_uint;
fn main() {
let args: Vec<String> = std::env::args().collect();
if args.len() < 2 {
eprintln!("Usage: {} <network_file>", args[0]);
std::process::exit(1);
}
let network_file = &args[1];
match Fann::from_file(network_file) {
Ok(ann) => {
// Output layer sizes
let layer_sizes = ann.get_layer_sizes();
let bias_counts = ann.get_bias_counts();
println!("Layers:");
for (layer_size, bias_count) in layer_sizes.iter().zip(bias_counts.iter()) {
println!("{} {}", layer_size, bias_count);
}
// Output connections
println!("Connections:");
let connections = ann.get_connections();
for connection in connections {
println!("{} {} {}", connection.from_neuron, connection.to_neuron, connection.weight);
}
},
Err(err) => {
eprintln!("Error loading network from file {}: {}", network_file, err);
std::process::exit(1);
}
}
}

View file

@@ -19,4 +19,7 @@ serde = { version = "1.0", features = ["derive"] }
thiserror = "1.0"
anyhow = "1.0"
bincode = "1.3.3"
log = "0.4.14"
serde_json = "1.0.114"
tokio = { version = "1.37.0", features = ["full"] }
futures = "0.3.30"

View file

@@ -0,0 +1,5 @@
#[derive(Debug)]
pub enum DataFormat {
Bincode,
Json,
}
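
A hedged sketch of what a consumer does with this enum; it mirrors the match arms added to FileLinked::write_data later in this diff, and `encode` is illustrative rather than part of the crate:

fn encode<T: serde::Serialize>(val: &T, format: &DataFormat) -> anyhow::Result<Vec<u8>> {
    Ok(match format {
        // Compact binary encoding via bincode 1.x.
        DataFormat::Bincode => bincode::serialize(val)?,
        // Human-readable encoding via serde_json.
        DataFormat::Json => serde_json::to_vec(val)?,
    })
}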

View file

@@ -0,0 +1 @@
pub mod data_format;

View file

@@ -1,8 +1,10 @@
//! A wrapper around an object that ties it to a physical file
pub mod constants;
pub mod error;
use anyhow::{anyhow, Context};
use constants::data_format::DataFormat;
use error::Error;
use log::info;
use serde::{de::DeserializeOwned, Serialize};
@@ -10,9 +12,10 @@ use std::{
fs::{copy, remove_file, File},
io::{ErrorKind, Write},
path::{Path, PathBuf},
thread,
thread::JoinHandle,
sync::Arc,
thread::{self, JoinHandle},
};
use tokio::sync::RwLock;
/// A wrapper around an object `T` that ties the object to a physical file
#[derive(Debug)]
@@ -20,10 +23,11 @@ pub struct FileLinked<T>
where
T: Serialize,
{
val: T,
val: Arc<RwLock<T>>,
path: PathBuf,
temp_file_path: PathBuf,
file_thread: Option<JoinHandle<()>>,
data_format: DataFormat,
}
impl<T> Drop for FileLinked<T>
@@ -48,10 +52,12 @@ where
/// # Examples
/// ```
/// # use file_linked::*;
/// # use file_linked::constants::data_format::DataFormat;
/// # use serde::{Deserialize, Serialize};
/// # use std::fmt;
/// # use std::string::ToString;
/// # use std::path::PathBuf;
/// # use tokio;
/// #
/// # #[derive(Deserialize, Serialize)]
/// # struct Test {
@@ -60,27 +66,30 @@ where
/// # pub c: f64
/// # }
/// #
/// # fn main() {
/// # #[tokio::main]
/// # async fn main() {
/// let test = Test {
/// a: 1,
/// b: String::from("two"),
/// c: 3.0
/// };
///
/// let linked_test = FileLinked::new(test, &PathBuf::from("./temp"))
/// let linked_test = FileLinked::new(test, &PathBuf::from("./temp"), DataFormat::Json).await
/// .expect("Unable to create file linked object");
///
/// assert_eq!(linked_test.readonly().a, 1);
/// assert_eq!(linked_test.readonly().b, String::from("two"));
/// assert_eq!(linked_test.readonly().c, 3.0);
/// let readonly = linked_test.readonly();
/// let readonly_ref = readonly.read().await;
/// assert_eq!(readonly_ref.a, 1);
/// assert_eq!(readonly_ref.b, String::from("two"));
/// assert_eq!(readonly_ref.c, 3.0);
/// #
/// # drop(linked_test);
/// #
/// # std::fs::remove_file("./temp").expect("Unable to remove file");
/// # }
/// ```
pub fn readonly(&self) -> &T {
&self.val
pub fn readonly(&self) -> Arc<RwLock<T>> {
self.val.clone()
}
/// Creates a new [`FileLinked`] object of type `T` stored to the file given by `path`.
@@ -88,10 +97,12 @@ where
/// # Examples
/// ```
/// # use file_linked::*;
/// # use file_linked::constants::data_format::DataFormat;
/// # use serde::{Deserialize, Serialize};
/// # use std::fmt;
/// # use std::string::ToString;
/// # use std::path::PathBuf;
/// # use tokio;
/// #
/// #[derive(Deserialize, Serialize)]
/// struct Test {
@@ -100,26 +111,29 @@ where
/// pub c: f64
/// }
///
/// # fn main() {
/// #[tokio::main]
/// # async fn main() {
/// let test = Test {
/// a: 1,
/// b: String::from("two"),
/// c: 3.0
/// };
///
/// let linked_test = FileLinked::new(test, &PathBuf::from("./temp"))
/// let linked_test = FileLinked::new(test, &PathBuf::from("./temp"), DataFormat::Json).await
/// .expect("Unable to create file linked object");
///
/// assert_eq!(linked_test.readonly().a, 1);
/// assert_eq!(linked_test.readonly().b, String::from("two"));
/// assert_eq!(linked_test.readonly().c, 3.0);
/// let readonly = linked_test.readonly();
/// let readonly_ref = readonly.read().await;
/// assert_eq!(readonly_ref.a, 1);
/// assert_eq!(readonly_ref.b, String::from("two"));
/// assert_eq!(readonly_ref.c, 3.0);
/// #
/// # drop(linked_test);
/// #
/// # std::fs::remove_file("./temp").expect("Unable to remove file");
/// # }
/// ```
pub fn new(val: T, path: &Path) -> Result<FileLinked<T>, Error> {
pub async fn new(val: T, path: &Path, data_format: DataFormat) -> Result<FileLinked<T>, Error> {
let mut temp_file_path = path.to_path_buf();
temp_file_path.set_file_name(format!(
".temp{}",
@@ -130,21 +144,28 @@ where
));
let mut result = FileLinked {
val,
val: Arc::new(RwLock::new(val)),
path: path.to_path_buf(),
temp_file_path,
file_thread: None,
data_format,
};
result.write_data()?;
result.write_data().await?;
Ok(result)
}
fn write_data(&mut self) -> Result<(), Error> {
async fn write_data(&mut self) -> Result<(), Error> {
let thread_path = self.path.clone();
let thread_temp_path = self.temp_file_path.clone();
let thread_val = bincode::serialize(&self.val)
.with_context(|| "Unable to serialize object into bincode".to_string())?;
let val = self.val.read().await;
let thread_val = match self.data_format {
DataFormat::Bincode => bincode::serialize(&*val)
.with_context(|| "Unable to serialize object into bincode".to_string())?,
DataFormat::Json => serde_json::to_vec(&*val)
.with_context(|| "Unable to serialize object into JSON".to_string())?,
};
if let Some(file_thread) = self.file_thread.take() {
file_thread
@@ -190,10 +211,12 @@ where
/// ```
/// # use file_linked::*;
/// # use file_linked::error::Error;
/// # use file_linked::constants::data_format::DataFormat;
/// # use serde::{Deserialize, Serialize};
/// # use std::fmt;
/// # use std::string::ToString;
/// # use std::path::PathBuf;
/// # use tokio;
/// #
/// # #[derive(Deserialize, Serialize)]
/// # struct Test {
@@ -202,21 +225,28 @@ where
/// # pub c: f64
/// # }
/// #
/// # fn main() -> Result<(), Error> {
/// # #[tokio::main]
/// # async fn main() -> Result<(), Error> {
/// let test = Test {
/// a: 1,
/// b: String::from(""),
/// c: 0.0
/// };
///
/// let mut linked_test = FileLinked::new(test, &PathBuf::from("./temp"))
/// let mut linked_test = FileLinked::new(test, &PathBuf::from("./temp"), DataFormat::Bincode).await
/// .expect("Unable to create file linked object");
///
/// assert_eq!(linked_test.readonly().a, 1);
/// {
/// let readonly = linked_test.readonly();
/// let readonly_ref = readonly.read().await;
/// assert_eq!(readonly_ref.a, 1);
/// }
///
/// linked_test.mutate(|t| t.a = 2)?;
/// linked_test.mutate(|t| t.a = 2).await?;
///
/// assert_eq!(linked_test.readonly().a, 2);
/// let readonly = linked_test.readonly();
/// let readonly_ref = readonly.read().await;
/// assert_eq!(readonly_ref.a, 2);
/// #
/// # drop(linked_test);
/// #
@@ -225,10 +255,15 @@ where
/// # Ok(())
/// # }
/// ```
pub fn mutate<U, F: FnOnce(&mut T) -> U>(&mut self, op: F) -> Result<U, Error> {
let result = op(&mut self.val);
pub async fn mutate<U, F: FnOnce(&mut T) -> U>(&mut self, op: F) -> Result<U, Error> {
let val_clone = self.val.clone(); // Arc<RwLock<T>>
let mut val = val_clone.write().await; // RwLockWriteGuard<T>
self.write_data()?;
let result = op(&mut val);
drop(val);
self.write_data().await?;
Ok(result)
}
@@ -239,10 +274,12 @@ where
/// ```
/// # use file_linked::*;
/// # use file_linked::error::Error;
/// # use file_linked::constants::data_format::DataFormat;
/// # use serde::{Deserialize, Serialize};
/// # use std::fmt;
/// # use std::string::ToString;
/// # use std::path::PathBuf;
/// # use tokio;
/// #
/// # #[derive(Deserialize, Serialize)]
/// # struct Test {
@@ -251,25 +288,30 @@ where
/// # pub c: f64
/// # }
/// #
/// # fn main() -> Result<(), Error> {
/// # #[tokio::main]
/// # async fn main() -> Result<(), Error> {
/// let test = Test {
/// a: 1,
/// b: String::from(""),
/// c: 0.0
/// };
///
/// let mut linked_test = FileLinked::new(test, &PathBuf::from("./temp"))
/// let mut linked_test = FileLinked::new(test, &PathBuf::from("./temp"), DataFormat::Bincode).await
/// .expect("Unable to create file linked object");
///
/// assert_eq!(linked_test.readonly().a, 1);
/// let readonly = linked_test.readonly();
/// let readonly_ref = readonly.read().await;
/// assert_eq!(readonly_ref.a, 1);
///
/// linked_test.replace(Test {
/// a: 2,
/// b: String::from(""),
/// c: 0.0
/// })?;
/// }).await?;
///
/// assert_eq!(linked_test.readonly().a, 2);
/// let readonly = linked_test.readonly();
/// let readonly_ref = readonly.read().await;
/// assert_eq!(readonly_ref.a, 2);
/// #
/// # drop(linked_test);
/// #
@@ -278,10 +320,30 @@ where
/// # Ok(())
/// # }
/// ```
pub fn replace(&mut self, val: T) -> Result<(), Error> {
self.val = val;
pub async fn replace(&mut self, val: T) -> Result<(), Error> {
self.val = Arc::new(RwLock::new(val));
self.write_data()
self.write_data().await
}
}
impl<T> FileLinked<T>
where
T: Serialize + DeserializeOwned + Send + 'static,
{
/// Asynchronously modifies the data contained in a `FileLinked` object using an async callback `op`.
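///
/// # Examples
///
/// A hedged sketch in the spirit of the other examples here (marked
/// `ignore` because it is illustrative, not a verified doc test):
///
/// ```ignore
/// # use file_linked::*;
/// # use file_linked::constants::data_format::DataFormat;
/// # use std::path::PathBuf;
/// #
/// # #[tokio::main]
/// # async fn main() {
/// let mut linked = FileLinked::new(vec![1, 2, 3], &PathBuf::from("./temp"), DataFormat::Json)
///     .await
///     .expect("Unable to create file linked object");
///
/// linked
///     .mutate_async(|val| async move {
///         let mut guard = val.write().await;
///         guard.push(4);
///     })
///     .await
///     .expect("Error mutating file linked object");
/// #
/// # drop(linked);
/// # std::fs::remove_file("./temp").expect("Unable to remove file");
/// # }
/// ```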
pub async fn mutate_async<F, Fut, U>(&mut self, op: F) -> Result<U, Error>
where
F: FnOnce(Arc<RwLock<T>>) -> Fut,
Fut: std::future::Future<Output = U> + Send,
U: Send,
{
let val_clone = self.val.clone();
let result = op(val_clone).await;
self.write_data().await?;
Ok(result)
}
}
@@ -295,6 +357,7 @@ where
/// ```
/// # use file_linked::*;
/// # use file_linked::error::Error;
/// # use file_linked::constants::data_format::DataFormat;
/// # use serde::{Deserialize, Serialize};
/// # use std::fmt;
/// # use std::string::ToString;
@@ -302,6 +365,7 @@ where
/// # use std::fs::OpenOptions;
/// # use std::io::Write;
/// # use std::path::PathBuf;
/// # use tokio;
/// #
/// # #[derive(Deserialize, Serialize)]
/// # struct Test {
@@ -310,7 +374,8 @@ where
/// # pub c: f64
/// # }
/// #
/// # fn main() -> Result<(), Error> {
/// # #[tokio::main]
/// # async fn main() -> Result<(), Error> {
/// let test = Test {
/// a: 1,
/// b: String::from("2"),
@@ -327,12 +392,14 @@ where
///
/// bincode::serialize_into(file, &test).expect("Unable to serialize object");
///
/// let mut linked_test = FileLinked::<Test>::from_file(&path)
/// let mut linked_test = FileLinked::<Test>::from_file(&path, DataFormat::Bincode)
/// .expect("Unable to create file linked object");
///
/// assert_eq!(linked_test.readonly().a, test.a);
/// assert_eq!(linked_test.readonly().b, test.b);
/// assert_eq!(linked_test.readonly().c, test.c);
/// let readonly = linked_test.readonly();
/// let readonly_ref = readonly.read().await;
/// assert_eq!(readonly_ref.a, test.a);
/// assert_eq!(readonly_ref.b, test.b);
/// assert_eq!(readonly_ref.c, test.c);
/// #
/// # drop(linked_test);
/// #
@@ -341,7 +408,7 @@ where
/// # Ok(())
/// # }
/// ```
pub fn from_file(path: &Path) -> Result<FileLinked<T>, Error> {
pub fn from_file(path: &Path, data_format: DataFormat) -> Result<FileLinked<T>, Error> {
let mut temp_file_path = path.to_path_buf();
temp_file_path.set_file_name(format!(
".temp{}",
@@ -351,16 +418,22 @@ where
.ok_or_else(|| anyhow!("Unable to get filename for tempfile {}", path.display()))?
));
match File::open(path).map_err(Error::from).and_then(|file| {
bincode::deserialize_from::<File, T>(file)
.with_context(|| format!("Unable to deserialize file {}", path.display()))
.map_err(Error::from)
}) {
match File::open(path)
.map_err(Error::from)
.and_then(|file| match data_format {
DataFormat::Bincode => bincode::deserialize_from::<File, T>(file)
.with_context(|| format!("Unable to deserialize file {}", path.display()))
.map_err(Error::from),
DataFormat::Json => serde_json::from_reader(file)
.with_context(|| format!("Unable to deserialize file {}", path.display()))
.map_err(Error::from),
}) {
Ok(val) => Ok(FileLinked {
val,
val: Arc::new(RwLock::new(val)),
path: path.to_path_buf(),
temp_file_path,
file_thread: None,
data_format,
}),
Err(err) => {
info!(
@@ -370,30 +443,43 @@ where
);
// Try to use temp file instead and see if that file exists and is serializable
let val = FileLinked::from_temp_file(&temp_file_path, path)
let val = FileLinked::from_temp_file(&temp_file_path, path, &data_format)
.map_err(|_| err)
.with_context(|| format!("Failed to read/deserialize the object from the file {} and temp file {}", path.display(), temp_file_path.display()))?;
Ok(FileLinked {
val,
val: Arc::new(RwLock::new(val)),
path: path.to_path_buf(),
temp_file_path,
file_thread: None,
data_format,
})
}
}
}
fn from_temp_file(temp_file_path: &Path, path: &Path) -> Result<T, Error> {
fn from_temp_file(
temp_file_path: &Path,
path: &Path,
data_format: &DataFormat,
) -> Result<T, Error> {
let file = File::open(temp_file_path)
.with_context(|| format!("Unable to open file {}", temp_file_path.display()))?;
let val = bincode::deserialize_from(file).with_context(|| {
format!(
"Could not deserialize from temp file {}",
temp_file_path.display()
)
})?;
let val = match data_format {
DataFormat::Bincode => bincode::deserialize_from(file).with_context(|| {
format!(
"Could not deserialize from temp file {}",
temp_file_path.display()
)
})?,
DataFormat::Json => serde_json::from_reader(file).with_context(|| {
format!(
"Could not deserialize from temp file {}",
temp_file_path.display()
)
})?,
};
info!("Successfully deserialized value from temp file");
@@ -421,8 +507,12 @@ mod tests {
}
}
pub fn run<F: FnOnce(&Path) -> Result<(), Error>>(&self, op: F) -> Result<(), Error> {
op(&self.path)
pub async fn run<F, Fut>(&self, op: F) -> ()
where
F: FnOnce(PathBuf) -> Fut,
Fut: std::future::Future<Output = ()>,
{
op(self.path.clone()).await
}
}
@@ -434,92 +524,173 @@ mod tests {
}
}
#[test]
fn test_readonly() -> Result<(), Error> {
#[tokio::test]
async fn test_readonly() {
let path = PathBuf::from("test_readonly");
let cleanup = CleanUp::new(&path);
cleanup.run(|p| {
let val = vec!["one", "two", ""];
cleanup
.run(|p| async move {
let val = vec!["one", "two", ""];
let linked_object = FileLinked::new(val.clone(), &p)?;
assert_eq!(*linked_object.readonly(), val);
Ok(())
})
let linked_object = FileLinked::new(val.clone(), &p, DataFormat::Json)
.await
.expect("Unable to create file linked object");
let linked_object_arc = linked_object.readonly();
let linked_object_ref = linked_object_arc.read().await;
assert_eq!(*linked_object_ref, val);
})
.await;
}
#[test]
fn test_new() -> Result<(), Error> {
#[tokio::test]
async fn test_new() {
let path = PathBuf::from("test_new");
let cleanup = CleanUp::new(&path);
cleanup.run(|p| {
let val = "test";
cleanup
.run(|p| async move {
let val = "test";
FileLinked::new(val, &p)?;
FileLinked::new(val, &p, DataFormat::Bincode)
.await
.expect("Unable to create file linked object");
let file = File::open(&p)?;
let result: String =
bincode::deserialize_from(file).expect("Unable to deserialize from file");
assert_eq!(result, val);
Ok(())
})
let file = File::open(&p).expect("Unable to open file");
let result: String =
bincode::deserialize_from(file).expect("Unable to deserialize from file");
assert_eq!(result, val);
})
.await;
}
#[test]
fn test_mutate() -> Result<(), Error> {
#[tokio::test]
async fn test_mutate() {
let path = PathBuf::from("test_mutate");
let cleanup = CleanUp::new(&path);
cleanup.run(|p| {
let list = vec![1, 2, 3, 4];
let mut file_linked_list = FileLinked::new(list, &p)?;
assert_eq!(*file_linked_list.readonly(), vec![1, 2, 3, 4]);
cleanup
.run(|p| async move {
let list = vec![1, 2, 3, 4];
let mut file_linked_list = FileLinked::new(list, &p, DataFormat::Json)
.await
.expect("Unable to create file linked object");
let file_linked_list_arc = file_linked_list.readonly();
let file_linked_list_ref = file_linked_list_arc.read().await;
file_linked_list.mutate(|v1| v1.push(5))?;
assert_eq!(*file_linked_list.readonly(), vec![1, 2, 3, 4, 5]);
assert_eq!(*file_linked_list_ref, vec![1, 2, 3, 4]);
file_linked_list.mutate(|v1| v1[1] = 1)?;
assert_eq!(*file_linked_list.readonly(), vec![1, 1, 3, 4, 5]);
drop(file_linked_list_ref);
file_linked_list
.mutate(|v1| v1.push(5))
.await
.expect("Error mutating file linked object");
let file_linked_list_arc = file_linked_list.readonly();
let file_linked_list_ref = file_linked_list_arc.read().await;
drop(file_linked_list);
Ok(())
})
assert_eq!(*file_linked_list_ref, vec![1, 2, 3, 4, 5]);
drop(file_linked_list_ref);
file_linked_list
.mutate(|v1| v1[1] = 1)
.await
.expect("Error mutating file linked object");
let file_linked_list_arc = file_linked_list.readonly();
let file_linked_list_ref = file_linked_list_arc.read().await;
assert_eq!(*file_linked_list_ref, vec![1, 1, 3, 4, 5]);
drop(file_linked_list);
})
.await;
}
#[test]
fn test_replace() -> Result<(), Error> {
#[tokio::test]
async fn test_async_mutate() {
let path = PathBuf::from("test_async_mutate");
let cleanup = CleanUp::new(&path);
cleanup
.run(|p| async move {
let list = vec![1, 2, 3, 4];
let mut file_linked_list = FileLinked::new(list, &p, DataFormat::Json)
.await
.expect("Unable to create file linked object");
let file_linked_list_arc = file_linked_list.readonly();
let file_linked_list_ref = file_linked_list_arc.read().await;
assert_eq!(*file_linked_list_ref, vec![1, 2, 3, 4]);
drop(file_linked_list_ref);
file_linked_list
.mutate_async(|v1| async move {
let mut v = v1.write().await;
v.push(5);
v[1] = 1;
Ok::<(), Error>(())
})
.await
.expect("Error mutating file linked object")
.expect("Error mutating file linked object");
let file_linked_list_arc = file_linked_list.readonly();
let file_linked_list_ref = file_linked_list_arc.read().await;
assert_eq!(*file_linked_list_ref, vec![1, 1, 3, 4, 5]);
drop(file_linked_list);
})
.await;
}
#[tokio::test]
async fn test_replace() {
let path = PathBuf::from("test_replace");
let cleanup = CleanUp::new(&path);
cleanup.run(|p| {
let val1 = String::from("val1");
let val2 = String::from("val2");
let mut file_linked_list = FileLinked::new(val1.clone(), &p)?;
assert_eq!(*file_linked_list.readonly(), val1);
cleanup
.run(|p| async move {
let val1 = String::from("val1");
let val2 = String::from("val2");
let mut file_linked_list = FileLinked::new(val1.clone(), &p, DataFormat::Bincode)
.await
.expect("Unable to create file linked object");
let file_linked_list_arc = file_linked_list.readonly();
let file_linked_list_ref = file_linked_list_arc.read().await;
file_linked_list.replace(val2.clone())?;
assert_eq!(*file_linked_list.readonly(), val2);
assert_eq!(*file_linked_list_ref, val1);
drop(file_linked_list);
Ok(())
})
file_linked_list
.replace(val2.clone())
.await
.expect("Error replacing file linked object");
let file_linked_list_arc = file_linked_list.readonly();
let file_linked_list_ref = file_linked_list_arc.read().await;
assert_eq!(*file_linked_list_ref, val2);
drop(file_linked_list);
})
.await;
}
#[test]
fn test_from_file() -> Result<(), Error> {
#[tokio::test]
async fn test_from_file() {
let path = PathBuf::from("test_from_file");
let cleanup = CleanUp::new(&path);
cleanup.run(|p| {
let value: Vec<f64> = vec![2.0, 3.0, 5.0];
let file = File::create(&p)?;
cleanup
.run(|p| async move {
let value: Vec<f64> = vec![2.0, 3.0, 5.0];
let file = File::create(&p).expect("Unable to create file");
bincode::serialize_into(&file, &value).expect("Unable to serialize into file");
drop(file);
bincode::serialize_into(&file, &value).expect("Unable to serialize into file");
drop(file);
let linked_object: FileLinked<Vec<f64>> = FileLinked::from_file(&p)?;
assert_eq!(*linked_object.readonly(), value);
let linked_object: FileLinked<Vec<f64>> =
FileLinked::from_file(&p, DataFormat::Bincode)
.expect("Unable to create file linked object");
let linked_object_arc = linked_object.readonly();
let linked_object_ref = linked_object_arc.read().await;
drop(linked_object);
Ok(())
})
assert_eq!(*linked_object_ref, value);
drop(linked_object);
})
.await;
}
}
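
Taken together, the new surface (a constructor that takes a DataFormat, Arc<RwLock<T>> reads, and awaited mutation) can be exercised end to end. A minimal sketch, assuming a tokio binary; the path and values are illustrative:

use file_linked::{constants::data_format::DataFormat, FileLinked};
use std::path::PathBuf;

#[tokio::main]
async fn main() {
    let path = PathBuf::from("./counts.json");

    // Create the object and persist it to disk as JSON.
    let mut counts = FileLinked::new(vec![0u32; 3], &path, DataFormat::Json)
        .await
        .expect("Unable to create file linked object");

    // Mutate through the write lock; each mutation rewrites the backing file.
    counts.mutate(|v| v[0] += 1).await.expect("Error mutating");

    // Dropping joins the background writer thread, flushing the file.
    drop(counts);

    // Reload the persisted state using the matching format.
    let reloaded: FileLinked<Vec<u32>> =
        FileLinked::from_file(&path, DataFormat::Json).expect("Unable to reload");
    assert_eq!(reloaded.readonly().read().await[0], 1);
}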

View file

@@ -15,18 +15,22 @@ categories = ["simulation"]
[dependencies]
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
uuid = { version = "0.8", features = ["serde", "v4"] }
clap = { version = "~2.27.0", features = ["yaml"] }
toml = "0.5.8"
uuid = { version = "1.7", features = ["serde", "v4"] }
clap = { version = "4.5.2", features = ["derive"] }
toml = "0.8.10"
regex = "1"
file_linked = { version = "0.1.0", path = "../file_linked" }
thiserror = "1.0"
anyhow = "1.0"
rand = "0.8.4"
log = "0.4.14"
env_logger = "0.9.0"
futures = "0.3.17"
smol = "1.2.5"
smol-potat = "1.1.2"
num_cpus = "1.13.0"
easy-parallel = "3.1.0"
rand = "0.8.5"
log = "0.4.21"
env_logger = "0.11.3"
futures = "0.3.30"
tokio = { version = "1.37.0", features = ["full"] }
num_cpus = "1.16.0"
easy-parallel = "3.3.1"
fann = "0.1.8"
async-trait = "0.1.78"
async-recursion = "1.1.0"
lerp = "0.5.0"
console-subscriber = "0.2.0"

11
gemla/build.rs Normal file
View file

@@ -0,0 +1,11 @@
fn main() {
// Replace this with the path to the directory containing `fann.lib`
let lib_dir = "F://vandomej/Downloads/vcpkg/packages/fann_x64-windows/lib";
println!("cargo:rustc-link-search=native={}", lib_dir);
println!("cargo:rustc-link-lib=static=fann");
// Use `dylib=fann` instead of `static=fann` if you're linking dynamically
// If there are any additional directories where the compiler can find header files, you can specify them like this:
// println!("cargo:include={}", path_to_include_directory);
}

View file

@@ -1,9 +0,0 @@
name: GEMLA
version: "0.1"
author: Jacob VanDomelen <jacob.vandome15@gmail.com>
about: Uses a genetic algorithm to generate a machine learning algorithm.
args:
- FILE:
help: Sets the input/output file for the program.
required: true
index: 1

View file

@@ -1,15 +0,0 @@
[[nodes]]
fabric_addr = "10.0.0.1:9999"
bridge_bind = "10.0.0.1:8888"
mem = "100 GiB"
cpu = 8
# [[nodes]]
# fabric_addr = "10.0.0.2:9999"
# mem = "100 GiB"
# cpu = 16
# [[nodes]]
# fabric_addr = "10.0.0.3:9999"
# mem = "100 GiB"
# cpu = 16

View file

@@ -1,74 +1,72 @@
#[macro_use]
extern crate clap;
extern crate gemla;
#[macro_use]
extern crate log;
mod fighter_nn;
mod test_state;
use anyhow::anyhow;
use clap::App;
use easy_parallel::Parallel;
use anyhow::Result;
use clap::Parser;
use fighter_nn::FighterNN;
use file_linked::constants::data_format::DataFormat;
use gemla::{
constants::args::FILE,
core::{Gemla, GemlaConfig},
error::{log_error, Error},
error::log_error,
};
use smol::{channel, channel::RecvError, future, Executor};
use std::{path::PathBuf, time::Instant};
use test_state::TestState;
// const NUM_THREADS: usize = 2;
#[derive(Parser)]
#[command(version, about, long_about = None)]
struct Args {
/// The file to read/write the dataset from/to.
#[arg(short, long)]
file: String,
}
/// Runs a simulation of a genetic algorithm against a dataset.
///
/// Use the -h, --h, or --help flag to see usage syntax.
/// TODO
fn main() -> anyhow::Result<()> {
fn main() -> Result<()> {
env_logger::init();
info!("Starting");
// console_subscriber::init();
info!("Starting");
let now = Instant::now();
// Obtaining the number of threads to use
let num_threads = num_cpus::get().max(1);
let ex = Executor::new();
let (signal, shutdown) = channel::unbounded::<()>();
// Manually configure the Tokio runtime
let runtime: Result<()> = tokio::runtime::Builder::new_multi_thread()
.worker_threads(num_cpus::get())
// .worker_threads(NUM_THREADS)
.build()?
.block_on(async {
let args = Args::parse(); // Assuming Args::parse() doesn't need to be async
let mut gemla = log_error(
Gemla::<FighterNN>::new(
&PathBuf::from(args.file),
GemlaConfig { overwrite: false },
DataFormat::Json,
)
.await,
)?;
// Create an executor thread pool.
let (_, result): (Vec<Result<(), RecvError>>, Result<(), Error>) = Parallel::new()
.each(0..num_threads, |_| {
future::block_on(ex.run(shutdown.recv()))
})
.finish(|| {
smol::block_on(async {
drop(signal);
// let gemla_arc = Arc::new(gemla);
// Command line arguments are parsed with the clap crate. And this program uses
// the yaml method with clap.
let yaml = load_yaml!("../../cli.yml");
let matches = App::from_yaml(yaml).get_matches();
// Setup your application logic here
// If `gemla::simulate` needs to run sequentially, simply call it in sequence without spawning new tasks
// Checking that the first argument <FILE> is a valid file
if let Some(file_path) = matches.value_of(FILE) {
let mut gemla = log_error(Gemla::<TestState>::new(
&PathBuf::from(file_path),
GemlaConfig {
generations_per_node: 3,
overwrite: true,
},
))?;
log_error(gemla.simulate(3).await)?;
Ok(())
} else {
Err(Error::Other(anyhow!("Invalid argument for FILE")))
}
})
// Placeholder loop to continuously run the simulation
loop {
    // One simulation step per iteration
gemla.simulate(1).await?;
}
});
result?;
runtime?; // Handle errors from the block_on call
info!("Finished in {:?}", now.elapsed());
Ok(())
}
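
The loop above runs until the process is killed. A hedged sketch of one way to make it stoppable with tokio's ctrl_c signal (not part of this commit):

loop {
    tokio::select! {
        // Run one simulation step and propagate any simulation error.
        result = gemla.simulate(1) => { result?; }
        // Break out cleanly when Ctrl-C is received.
        _ = tokio::signal::ctrl_c() => {
            info!("Shutdown requested, exiting simulation loop");
            break;
        }
    }
}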

View file

@@ -0,0 +1,79 @@
use std::sync::Arc;
use serde::ser::SerializeTuple;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use tokio::sync::Semaphore;
const SHARED_SEMAPHORE_CONCURRENCY_LIMIT: usize = 50;
const VISIBLE_SIMULATIONS_CONCURRENCY_LIMIT: usize = 1;
#[derive(Debug, Clone)]
pub struct FighterContext {
pub shared_semaphore: Arc<Semaphore>,
pub visible_simulations: Arc<Semaphore>,
}
impl Default for FighterContext {
fn default() -> Self {
FighterContext {
shared_semaphore: Arc::new(Semaphore::new(SHARED_SEMAPHORE_CONCURRENCY_LIMIT)),
visible_simulations: Arc::new(Semaphore::new(VISIBLE_SIMULATIONS_CONCURRENCY_LIMIT)),
}
}
}
// Custom serialization to just output the concurrency limit.
impl Serialize for FighterContext {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
// Assuming the semaphore's available permits represent the concurrency limit.
// This part is tricky since Semaphore does not expose its initial permits.
// You might need to store the concurrency limit as a separate field if this assumption doesn't hold.
let concurrency_limit = SHARED_SEMAPHORE_CONCURRENCY_LIMIT;
let visible_concurrency_limit = VISIBLE_SIMULATIONS_CONCURRENCY_LIMIT;
// serializer.serialize_u64(concurrency_limit as u64)
// Serialize the concurrency limit as a tuple
let mut state = serializer.serialize_tuple(2)?;
state.serialize_element(&concurrency_limit)?;
state.serialize_element(&visible_concurrency_limit)?;
state.end()
}
}
// Custom deserialization to reconstruct the FighterContext from a concurrency limit.
impl<'de> Deserialize<'de> for FighterContext {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
// Deserialize the tuple
let (_, _) = <(usize, usize)>::deserialize(deserializer)?;
Ok(FighterContext {
shared_semaphore: Arc::new(Semaphore::new(SHARED_SEMAPHORE_CONCURRENCY_LIMIT)),
visible_simulations: Arc::new(Semaphore::new(VISIBLE_SIMULATIONS_CONCURRENCY_LIMIT)),
})
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_serialization() {
let context = FighterContext::default();
let serialized = serde_json::to_string(&context).unwrap();
let deserialized: FighterContext = serde_json::from_str(&serialized).unwrap();
assert_eq!(
context.shared_semaphore.available_permits(),
deserialized.shared_semaphore.available_permits()
);
assert_eq!(
context.visible_simulations.available_permits(),
deserialized.visible_simulations.available_permits()
);
}
}
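
As the serialization comments concede, Semaphore does not expose its initial permit count, so the round trip above only works while the limits are compile-time constants. A hedged sketch of the alternative the comments suggest, storing the limits as fields (the type and field names are illustrative):

// Sketch only: keep the configured limits as serializable fields and
// rebuild the semaphores after deserialization.
#[derive(Debug, Clone)]
pub struct FighterContextWithLimits {
    pub shared_limit: usize,
    pub visible_limit: usize,
    pub shared_semaphore: Arc<Semaphore>,
    pub visible_simulations: Arc<Semaphore>,
}

impl FighterContextWithLimits {
    pub fn new(shared_limit: usize, visible_limit: usize) -> Self {
        FighterContextWithLimits {
            shared_limit,
            visible_limit,
            shared_semaphore: Arc::new(Semaphore::new(shared_limit)),
            visible_simulations: Arc::new(Semaphore::new(visible_limit)),
        }
    }
}

// Serialize the two stored limits as a tuple, mirroring the impl above.
impl Serialize for FighterContextWithLimits {
    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        let mut state = serializer.serialize_tuple(2)?;
        state.serialize_element(&self.shared_limit)?;
        state.serialize_element(&self.visible_limit)?;
        state.end()
    }
}

// Rebuild the semaphores from the deserialized limits.
impl<'de> Deserialize<'de> for FighterContextWithLimits {
    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
        let (shared_limit, visible_limit) = <(usize, usize)>::deserialize(deserializer)?;
        Ok(FighterContextWithLimits::new(shared_limit, visible_limit))
    }
}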

File diff suppressed because it is too large

File diff suppressed because it is too large

View file

@@ -1,6 +1,11 @@
use gemla::{core::genetic_node::GeneticNode, error::Error};
use async_trait::async_trait;
use gemla::{
core::genetic_node::{GeneticNode, GeneticNodeContext},
error::Error,
};
use rand::prelude::*;
use serde::{Deserialize, Serialize};
use uuid::Uuid;
const POPULATION_SIZE: u64 = 5;
const POPULATION_REDUCTION_SIZE: u64 = 3;
@@ -8,20 +13,30 @@ const POPULATION_REDUCTION_SIZE: u64 = 3;
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct TestState {
pub population: Vec<i64>,
pub max_generations: u64,
}
#[async_trait]
impl GeneticNode for TestState {
fn initialize() -> Result<Box<Self>, Error> {
type Context = ();
async fn initialize(_context: GeneticNodeContext<Self::Context>) -> Result<Box<Self>, Error> {
let mut population: Vec<i64> = vec![];
for _ in 0..POPULATION_SIZE {
population.push(thread_rng().gen_range(0..100))
}
Ok(Box::new(TestState { population }))
Ok(Box::new(TestState {
population,
max_generations: 10,
}))
}
fn simulate(&mut self) -> Result<(), Error> {
async fn simulate(
&mut self,
context: GeneticNodeContext<Self::Context>,
) -> Result<bool, Error> {
let mut rng = thread_rng();
self.population = self
@@ -30,10 +45,14 @@ impl GeneticNode for TestState {
.map(|p| p.saturating_add(rng.gen_range(-1..2)))
.collect();
Ok(())
if context.generation >= self.max_generations {
Ok(false)
} else {
Ok(true)
}
}
fn mutate(&mut self) -> Result<(), Error> {
async fn mutate(&mut self, _context: GeneticNodeContext<Self::Context>) -> Result<(), Error> {
let mut rng = thread_rng();
let mut v = self.population.clone();
@@ -71,7 +90,12 @@ impl GeneticNode for TestState {
Ok(())
}
fn merge(left: &TestState, right: &TestState) -> Result<Box<TestState>, Error> {
async fn merge(
left: &TestState,
right: &TestState,
id: &Uuid,
gemla_context: Self::Context,
) -> Result<Box<TestState>, Error> {
let mut v = left.population.clone();
v.append(&mut right.population.clone());
@@ -80,9 +104,18 @@ impl GeneticNode for TestState {
v = v[..(POPULATION_REDUCTION_SIZE as usize)].to_vec();
let mut result = TestState { population: v };
let mut result = TestState {
population: v,
max_generations: 10,
};
result.mutate()?;
result
.mutate(GeneticNodeContext {
id: *id,
generation: 0,
gemla_context,
})
.await?;
Ok(Box::new(result))
}
@@ -93,57 +126,97 @@ mod tests {
use super::*;
use gemla::core::genetic_node::GeneticNode;
#[test]
fn test_initialize() {
let state = TestState::initialize().unwrap();
#[tokio::test]
async fn test_initialize() {
let state = TestState::initialize(GeneticNodeContext {
id: Uuid::new_v4(),
generation: 0,
gemla_context: (),
})
.await
.unwrap();
assert_eq!(state.population.len(), POPULATION_SIZE as usize);
}
#[test]
fn test_simulate() {
#[tokio::test]
async fn test_simulate() {
let mut state = TestState {
population: vec![1, 1, 2, 3],
max_generations: 1,
};
let original_population = state.population.clone();
state.simulate().unwrap();
state
.simulate(GeneticNodeContext {
id: Uuid::new_v4(),
generation: 0,
gemla_context: (),
})
.await
.unwrap();
assert!(original_population
.iter()
.zip(state.population.iter())
.all(|(&a, &b)| b >= a - 1 && b <= a + 2));
state.simulate().unwrap();
state.simulate().unwrap();
state
.simulate(GeneticNodeContext {
id: Uuid::new_v4(),
generation: 0,
gemla_context: (),
})
.await
.unwrap();
state
.simulate(GeneticNodeContext {
id: Uuid::new_v4(),
generation: 0,
gemla_context: (),
})
.await
.unwrap();
assert!(original_population
.iter()
.zip(state.population.iter())
.all(|(&a, &b)| b >= a - 3 && b <= a + 6))
}
#[test]
fn test_mutate() {
#[tokio::test]
async fn test_mutate() {
let mut state = TestState {
population: vec![4, 3, 3],
max_generations: 1,
};
state.mutate().unwrap();
state
.mutate(GeneticNodeContext {
id: Uuid::new_v4(),
generation: 0,
gemla_context: (),
})
.await
.unwrap();
assert_eq!(state.population.len(), POPULATION_SIZE as usize);
}
#[test]
fn test_merge() {
#[tokio::test]
async fn test_merge() {
let state1 = TestState {
population: vec![1, 2, 4, 5],
max_generations: 1,
};
let state2 = TestState {
population: vec![0, 1, 3, 7],
max_generations: 1,
};
let merged_state = TestState::merge(&state1, &state2).unwrap();
let merged_state = TestState::merge(&state1, &state2, &Uuid::new_v4(), ())
.await
.unwrap();
assert_eq!(merged_state.population.len(), POPULATION_SIZE as usize);
assert!(merged_state.population.iter().any(|&x| x == 7));

View file

@ -1,2 +0,0 @@
/// Corresponds to the FILE command line argument used in accordance with the clap crate.
pub const FILE: &str = "FILE";

View file

@ -1 +0,0 @@
pub mod args;

View file

@@ -5,7 +5,9 @@
use crate::error::Error;
use anyhow::Context;
use serde::{Deserialize, Serialize};
use async_trait::async_trait;
use log::info;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use std::fmt::Debug;
use uuid::Uuid;
@@ -24,45 +26,65 @@ pub enum GeneticState {
Finish,
}
#[derive(Clone, Debug)]
pub struct GeneticNodeContext<S> {
pub generation: u64,
pub id: Uuid,
pub gemla_context: S,
}
/// A trait used to interact with the internal state of nodes within the [`Bracket`]
///
/// [`Bracket`]: crate::bracket::Bracket
pub trait GeneticNode {
#[async_trait]
pub trait GeneticNode: Send {
type Context;
/// Initializes a new instance of a [`GeneticState`].
///
/// # Examples
/// TODO
fn initialize() -> Result<Box<Self>, Error>;
async fn initialize(context: GeneticNodeContext<Self::Context>) -> Result<Box<Self>, Error>;
fn simulate(&mut self) -> Result<(), Error>;
async fn simulate(&mut self, context: GeneticNodeContext<Self::Context>)
-> Result<bool, Error>;
/// Mutates members in a population and/or crossbreeds them to produce new offspring.
///
/// # Examples
/// TODO
fn mutate(&mut self) -> Result<(), Error>;
async fn mutate(&mut self, context: GeneticNodeContext<Self::Context>) -> Result<(), Error>;
fn merge(left: &Self, right: &Self) -> Result<Box<Self>, Error>;
async fn merge(
left: &Self,
right: &Self,
id: &Uuid,
context: Self::Context,
) -> Result<Box<Self>, Error>;
}
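
The trait's `# Examples` sections are still TODO; a hedged sketch of a minimal implementor against the new async signatures (`Counter` is illustrative and not part of the crate):

#[derive(Clone, Debug, Serialize, Deserialize)]
struct Counter {
    value: u64,
}

#[async_trait]
impl GeneticNode for Counter {
    type Context = ();

    async fn initialize(_context: GeneticNodeContext<Self::Context>) -> Result<Box<Self>, Error> {
        Ok(Box::new(Counter { value: 0 }))
    }

    // Returning false tells the wrapper to transition to GeneticState::Finish.
    async fn simulate(&mut self, context: GeneticNodeContext<Self::Context>) -> Result<bool, Error> {
        self.value += 1;
        Ok(context.generation < 3)
    }

    async fn mutate(&mut self, _context: GeneticNodeContext<Self::Context>) -> Result<(), Error> {
        Ok(())
    }

    async fn merge(
        left: &Self,
        right: &Self,
        _id: &Uuid,
        _context: Self::Context,
    ) -> Result<Box<Self>, Error> {
        // Keep the stronger of the two merged individuals.
        Ok(Box::new(Counter {
            value: left.value.max(right.value),
        }))
    }
}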
/// Used externally to wrap a node implementing the [`GeneticNode`] trait. Processes state transitions for the given node as
/// well as signal recovery. Transition states are given by [`GeneticState`]
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct GeneticNodeWrapper<T> {
pub struct GeneticNodeWrapper<T>
where
T: Clone,
{
node: Option<T>,
state: GeneticState,
generation: u64,
max_generations: u64,
id: Uuid,
}
impl<T> Default for GeneticNodeWrapper<T> {
impl<T> Default for GeneticNodeWrapper<T>
where
T: Clone,
{
fn default() -> Self {
GeneticNodeWrapper {
node: None,
state: GeneticState::Initialize,
generation: 1,
max_generations: 1,
id: Uuid::new_v4(),
}
}
@@ -70,21 +92,20 @@ impl<T> Default for GeneticNodeWrapper<T> {
impl<T> GeneticNodeWrapper<T>
where
T: GeneticNode + Debug,
T: GeneticNode + Debug + Send + Clone,
T::Context: Send + Sync + Clone + Debug + Serialize + DeserializeOwned + 'static + Default,
{
pub fn new(max_generations: u64) -> Self {
pub fn new() -> Self {
GeneticNodeWrapper::<T> {
max_generations,
..Default::default()
}
}
pub fn from(data: T, max_generations: u64, id: Uuid) -> Self {
pub fn from(data: T, id: Uuid) -> Self {
GeneticNodeWrapper {
node: Some(data),
state: GeneticState::Simulate,
generation: 1,
max_generations,
id,
}
}
@@ -93,36 +114,51 @@ where
self.node.as_ref()
}
pub fn take(&mut self) -> Option<T> {
self.node.take()
}
pub fn id(&self) -> Uuid {
self.id
}
pub fn max_generations(&self) -> u64 {
self.max_generations
pub fn generation(&self) -> u64 {
self.generation
}
pub fn state(&self) -> GeneticState {
self.state
}
pub fn process_node(&mut self) -> Result<GeneticState, Error> {
pub async fn process_node(&mut self, gemla_context: T::Context) -> Result<GeneticState, Error> {
let context = GeneticNodeContext {
generation: self.generation,
id: self.id,
gemla_context,
};
match (self.state, &mut self.node) {
(GeneticState::Initialize, _) => {
self.node = Some(*T::initialize()?);
self.node = Some(*T::initialize(context.clone()).await?);
self.state = GeneticState::Simulate;
}
(GeneticState::Simulate, Some(n)) => {
n.simulate()
let next_generation = n
.simulate(context.clone())
.await
.with_context(|| format!("Error simulating node: {:?}", self))?;
self.state = if self.generation >= self.max_generations {
GeneticState::Finish
} else {
info!("Simulation complete and continuing: {:?}", next_generation);
self.state = if next_generation {
GeneticState::Mutate
} else {
GeneticState::Finish
};
}
(GeneticState::Mutate, Some(n)) => {
n.mutate()
n.mutate(context.clone())
.await
.with_context(|| format!("Error mutating node: {:?}", self))?;
self.generation += 1;
@@ -141,40 +177,64 @@ mod tests {
use super::*;
use crate::error::Error;
use anyhow::anyhow;
use async_trait::async_trait;
#[derive(Deserialize, Serialize, Clone, Debug, PartialEq)]
struct TestState {
pub score: f64,
pub max_generations: u64,
}
#[async_trait]
impl GeneticNode for TestState {
fn simulate(&mut self) -> Result<(), Error> {
type Context = ();
async fn simulate(
&mut self,
context: GeneticNodeContext<Self::Context>,
) -> Result<bool, Error> {
self.score += 1.0;
if context.generation >= self.max_generations {
Ok(false)
} else {
Ok(true)
}
}
async fn mutate(
&mut self,
_context: GeneticNodeContext<Self::Context>,
) -> Result<(), Error> {
Ok(())
}
fn mutate(&mut self) -> Result<(), Error> {
Ok(())
async fn initialize(
_context: GeneticNodeContext<Self::Context>,
) -> Result<Box<TestState>, Error> {
Ok(Box::new(TestState {
score: 0.0,
max_generations: 2,
}))
}
fn initialize() -> Result<Box<TestState>, Error> {
Ok(Box::new(TestState { score: 0.0 }))
}
fn merge(_l: &TestState, _r: &TestState) -> Result<Box<TestState>, Error> {
async fn merge(
_l: &TestState,
_r: &TestState,
_id: &Uuid,
_: Self::Context,
) -> Result<Box<TestState>, Error> {
Err(Error::Other(anyhow!("Unable to merge")))
}
}
#[test]
fn test_new() -> Result<(), Error> {
let genetic_node = GeneticNodeWrapper::<TestState>::new(10);
let genetic_node = GeneticNodeWrapper::<TestState>::new();
let other_genetic_node = GeneticNodeWrapper::<TestState> {
node: None,
state: GeneticState::Initialize,
generation: 1,
max_generations: 10,
id: genetic_node.id(),
};
@@ -185,15 +245,17 @@ mod tests {
#[test]
fn test_from() -> Result<(), Error> {
let val = TestState { score: 0.0 };
let val = TestState {
score: 0.0,
max_generations: 10,
};
let uuid = Uuid::new_v4();
let genetic_node = GeneticNodeWrapper::from(val.clone(), 10, uuid);
let genetic_node = GeneticNodeWrapper::from(val.clone(), uuid);
let other_genetic_node = GeneticNodeWrapper::<TestState> {
node: Some(val),
state: GeneticState::Simulate,
generation: 1,
max_generations: 10,
id: genetic_node.id(),
};
@@ -204,9 +266,12 @@ where
#[test]
fn test_as_ref() -> Result<(), Error> {
let val = TestState { score: 3.0 };
let val = TestState {
score: 3.0,
max_generations: 10,
};
let uuid = Uuid::new_v4();
let genetic_node = GeneticNodeWrapper::from(val.clone(), 10, uuid);
let genetic_node = GeneticNodeWrapper::from(val.clone(), uuid);
let ref_value = genetic_node.as_ref().unwrap();
@@ -217,9 +282,12 @@ where
#[test]
fn test_id() -> Result<(), Error> {
let val = TestState { score: 3.0 };
let val = TestState {
score: 3.0,
max_generations: 10,
};
let uuid = Uuid::new_v4();
let genetic_node = GeneticNodeWrapper::from(val.clone(), 10, uuid);
let genetic_node = GeneticNodeWrapper::from(val.clone(), uuid);
let id_value = genetic_node.id();
@@ -228,24 +296,14 @@ where
Ok(())
}
#[test]
fn test_max_generations() -> Result<(), Error> {
let val = TestState { score: 3.0 };
let uuid = Uuid::new_v4();
let genetic_node = GeneticNodeWrapper::from(val.clone(), 10, uuid);
let max_generations = genetic_node.max_generations();
assert_eq!(max_generations, 10);
Ok(())
}
#[test]
fn test_state() -> Result<(), Error> {
let val = TestState { score: 3.0 };
let val = TestState {
score: 3.0,
max_generations: 10,
};
let uuid = Uuid::new_v4();
let genetic_node = GeneticNodeWrapper::from(val.clone(), 10, uuid);
let genetic_node = GeneticNodeWrapper::from(val.clone(), uuid);
let state = genetic_node.state();
@@ -254,16 +312,16 @@ where
Ok(())
}
#[test]
fn test_process_node() -> Result<(), Error> {
let mut genetic_node = GeneticNodeWrapper::<TestState>::new(2);
#[tokio::test]
async fn test_process_node() -> Result<(), Error> {
let mut genetic_node = GeneticNodeWrapper::<TestState>::new();
assert_eq!(genetic_node.state(), GeneticState::Initialize);
assert_eq!(genetic_node.process_node()?, GeneticState::Simulate);
assert_eq!(genetic_node.process_node()?, GeneticState::Mutate);
assert_eq!(genetic_node.process_node()?, GeneticState::Simulate);
assert_eq!(genetic_node.process_node()?, GeneticState::Finish);
assert_eq!(genetic_node.process_node()?, GeneticState::Finish);
assert_eq!(genetic_node.process_node(()).await?, GeneticState::Simulate);
assert_eq!(genetic_node.process_node(()).await?, GeneticState::Mutate);
assert_eq!(genetic_node.process_node(()).await?, GeneticState::Simulate);
assert_eq!(genetic_node.process_node(()).await?, GeneticState::Finish);
assert_eq!(genetic_node.process_node(()).await?, GeneticState::Finish);
Ok(())
}

View file

@@ -4,42 +4,44 @@
pub mod genetic_node;
use crate::{error::Error, tree::Tree};
use file_linked::FileLinked;
use futures::{future, future::BoxFuture};
use async_recursion::async_recursion;
use file_linked::{constants::data_format::DataFormat, FileLinked};
use futures::future;
use genetic_node::{GeneticNode, GeneticNodeWrapper, GeneticState};
use log::{info, trace, warn};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use std::{
collections::HashMap, fmt::Debug, fs::File, io::ErrorKind, marker::Send, mem, path::Path,
time::Instant,
sync::Arc, time::Instant,
};
use tokio::{sync::RwLock, task::JoinHandle};
use uuid::Uuid;
type SimulationTree<T> = Box<Tree<GeneticNodeWrapper<T>>>;
/// Provides configuration options for managing a [`Gemla`] object as it executes.
///
/// # Examples
/// ```
/// ```rust,ignore
/// #[derive(Deserialize, Serialize, Clone, Debug, PartialEq)]
/// struct TestState {
/// pub score: f64,
/// }
///
/// impl genetic_node::GeneticNode for TestState {
/// fn simulate(&mut self) -> Result<(), Error> {
/// self.score += 1.0;
/// Ok(())
/// }
///
/// fn mutate(&mut self) -> Result<(), Error> {
/// Ok(())
/// }
///
///
/// Ok(Box::new(TestState { score: 0.0 }))
/// }
///
///
/// Ok(Box::new(if left.score > right.score {
/// left.clone()
@@ -48,14 +50,13 @@ type SimulationTree<T> = Box<Tree<GeneticNodeWrapper<T>>>;
/// }))
/// }
/// }
///
/// fn main() {
///
/// }
/// ```
#[derive(Serialize, Deserialize, Copy, Clone)]
pub struct GemlaConfig {
pub generations_per_node: u64,
pub overwrite: bool,
}
@@ -65,79 +66,125 @@ pub struct GemlaConfig
/// individuals.
///
/// [`GeneticNode`]: genetic_node::GeneticNode
pub struct Gemla<'a, T>
pub struct Gemla<T>
where
T: Serialize + Clone,
T: GeneticNode + Serialize + DeserializeOwned + Debug + Send + Clone,
T::Context: Send + Sync + Clone + Debug + Serialize + DeserializeOwned + 'static + Default,
{
pub data: FileLinked<(Option<SimulationTree<T>>, GemlaConfig)>,
threads: HashMap<Uuid, BoxFuture<'a, Result<GeneticNodeWrapper<T>, Error>>>,
pub data: FileLinked<(Option<SimulationTree<T>>, GemlaConfig, T::Context)>,
threads: HashMap<Uuid, JoinHandle<Result<GeneticNodeWrapper<T>, Error>>>,
}
impl<'a, T: 'a> Gemla<'a, T>
impl<T: 'static> Gemla<T>
where
T: GeneticNode + Serialize + DeserializeOwned + Debug + Clone + Send,
T: GeneticNode + Serialize + DeserializeOwned + Debug + Send + Sync + Clone,
T::Context: Send + Sync + Clone + Debug + Serialize + DeserializeOwned + 'static + Default,
{
pub fn new(path: &Path, config: GemlaConfig) -> Result<Self, Error> {
pub async fn new(
path: &Path,
config: GemlaConfig,
data_format: DataFormat,
) -> Result<Self, Error> {
match File::open(path) {
// If the file exists we either want to overwrite the file or read from the file
// If the file exists we either want to overwrite the file or read from the file
Ok(_) => Ok(Gemla {
data: if config.overwrite {
FileLinked::new((None, config), path)?
FileLinked::new((None, config, T::Context::default()), path, data_format)
.await?
} else {
FileLinked::from_file(path)?
FileLinked::from_file(path, data_format)?
},
threads: HashMap::new(),
}),
// If the file doesn't exist we must create it
Err(error) if error.kind() == ErrorKind::NotFound => Ok(Gemla {
data: FileLinked::new((None, config), path)?,
data: FileLinked::new((None, config, T::Context::default()), path, data_format)
.await?,
threads: HashMap::new(),
}),
Err(error) => Err(Error::IO(error)),
}
}
pub fn tree_ref(&self) -> Option<&SimulationTree<T>> {
self.data.readonly().0.as_ref()
pub fn tree_ref(&self) -> Arc<RwLock<(Option<SimulationTree<T>>, GemlaConfig, T::Context)>> {
self.data.readonly().clone()
}
pub async fn simulate(&mut self, steps: u64) -> Result<(), Error> {
// Before we can process nodes we must create blank nodes in their place to keep track of which nodes have been processed
// in the tree and which nodes have not.
self.data.mutate(|(d, c)| {
let mut tree: Option<SimulationTree<T>> = Gemla::increase_height(d.take(), c, steps);
mem::swap(d, &mut tree);
})?;
let tree_completed = {
// Only increase height if the tree is uninitialized or completed
let data_arc = self.data.readonly();
let data_ref = data_arc.read().await;
let tree_ref = data_ref.0.as_ref();
info!(
"Height of simulation tree increased to {}",
self.tree_ref()
.map(|t| format!("{}", t.height()))
.unwrap_or_else(|| "Tree is not defined".to_string())
);
tree_ref.is_none() || tree_ref.map(|t| Gemla::is_completed(t)).unwrap_or(true)
};
if tree_completed {
// Before we can process nodes we must create blank nodes in their place to keep track of which nodes have been processed
// in the tree and which nodes have not.
self.data
.mutate(|(d, _, _)| {
let mut tree: Option<SimulationTree<T>> =
Gemla::increase_height(d.take(), steps);
mem::swap(d, &mut tree);
})
.await?;
}
{
// Log the height of the simulation tree after any increase
let data_arc = self.data.readonly();
let data_ref = data_arc.read().await;
let tree_ref = data_ref.0.as_ref();
info!(
"Height of simulation tree increased to {}",
tree_ref
.map(|t| format!("{}", t.height()))
.unwrap_or_else(|| "Tree is not defined".to_string())
);
}
loop {
// We need to keep simulating until the tree has been completely processed.
if self
.tree_ref()
.map(|t| Gemla::is_completed(t))
.unwrap_or(false)
let is_tree_processed;
{
let data_arc = self.data.readonly();
let data_ref = data_arc.read().await;
let tree_ref = data_ref.0.as_ref();
is_tree_processed = tree_ref.map(|t| Gemla::is_completed(t)).unwrap_or(false)
}
// We need to keep simulating until the tree has been completely processed.
if is_tree_processed {
self.join_threads().await?;
info!("Processed tree");
break;
}
if let Some(node) = self
.tree_ref()
.and_then(|t| self.get_unprocessed_node(t))
{
let (node, gemla_context) = {
let data_arc = self.data.readonly();
let data_ref = data_arc.read().await;
let (tree_ref, _, gemla_context) = &*data_ref; // (Option<Box<Tree<GeneticNodeWrapper<T>>>, GemlaConfig, T::Context)
let node = tree_ref.as_ref().and_then(|t| self.get_unprocessed_node(t));
(node, gemla_context.clone())
};
if let Some(node) = node {
trace!("Adding node to process list {}", node.id());
self.threads
.insert(node.id(), Box::pin(Gemla::process_node(node)));
let gemla_context = gemla_context.clone();
self.threads.insert(
node.id(),
tokio::spawn(async move { Gemla::process_node(node, gemla_context).await }),
);
} else {
trace!("No node found to process, joining threads");
@@ -153,38 +200,56 @@ where
trace!("Joining threads for nodes {:?}", self.threads.keys());
let results = future::join_all(self.threads.values_mut()).await;
// Converting a list of results into a result wrapping the list
let reduced_results: Result<Vec<GeneticNodeWrapper<T>>, Error> =
results.into_iter().collect();
results.into_iter().flatten().collect();
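// Note: `flatten()` silently drops any task that failed to join (a panicked
// tokio task yields a `JoinError`), so only values actually returned by the
// spawned futures reach `reduced_results`.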
self.threads.clear();
// We need to retrieve the processed nodes from the resulting list and replace them in the original list
reduced_results.and_then(|r| {
self.data.mutate(|(d, _)| {
if let Some(t) = d {
let failed_nodes = Gemla::replace_nodes(t, r);
// We receive a list of nodes that were unable to be found in the original tree
if !failed_nodes.is_empty() {
warn!(
"Unable to find {:?} to replace in tree",
failed_nodes.iter().map(|n| n.id())
)
}
match reduced_results {
Ok(r) => {
self.data
.mutate_async(|d| async move {
// Scope to limit the duration of the read lock
let (_, context) = {
let data_read = d.read().await;
(data_read.1, data_read.2.clone())
}; // Read lock is dropped here
// Once the nodes are replaced we need to find nodes that can be merged from the completed children nodes
Gemla::merge_completed_nodes(t)
} else {
warn!("Unable to replce nodes {:?} in empty tree", r);
Ok(())
}
})?
})?;
let mut data_write = d.write().await;
if let Some(t) = data_write.0.as_mut() {
let failed_nodes = Gemla::replace_nodes(t, r);
// We receive a list of nodes that were unable to be found in the original tree
if !failed_nodes.is_empty() {
warn!(
"Unable to find {:?} to replace in tree",
failed_nodes.iter().map(|n| n.id())
)
}
// Once the nodes are replaced we need to find nodes that can be merged from the completed children nodes
Gemla::merge_completed_nodes(t, context.clone()).await
} else {
warn!("Unable to replce nodes {:?} in empty tree", r);
Ok(())
}
})
.await??;
}
Err(e) => return Err(e),
}
}
Ok(())
}
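// merge_completed_nodes walks the tree bottom-up: once both children of a
// still-uninitialized parent have finished, their results are merged into the
// parent; a lone finished child is copied upward instead.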
#[async_recursion]
async fn merge_completed_nodes<'a>(
tree: &'a mut SimulationTree<T>,
gemla_context: T::Context,
) -> Result<(), Error> {
if tree.val.state() == GeneticState::Initialize {
match (&mut tree.left, &mut tree.right) {
// If the current node has been initialized, and has children nodes that are completed, then we need
@ -195,43 +260,37 @@ where
{
info!("Merging nodes {} and {}", l.val.id(), r.val.id());
if let (Some(left_node), Some(right_node)) = (l.val.as_ref(), r.val.as_ref()) {
let merged_node = GeneticNode::merge(
left_node,
right_node,
&tree.val.id(),
gemla_context.clone(),
)
.await?;
tree.val = GeneticNodeWrapper::from(*merged_node, tree.val.id());
}
}
(Some(l), Some(r)) => {
Gemla::merge_completed_nodes(l, gemla_context.clone()).await?;
Gemla::merge_completed_nodes(r, gemla_context.clone()).await?;
}
// If there is only one child node that's completed then we want to copy it to the parent node
(Some(l), None) if l.val.state() == GeneticState::Finish => {
trace!("Copying node {}", l.val.id());
if let Some(left_node) = l.val.as_ref() {
tree.val = GeneticNodeWrapper::from(left_node.clone(), tree.val.id());
}
}
(Some(l), None) => Gemla::merge_completed_nodes(l, gemla_context.clone()).await?,
(None, Some(r)) if r.val.state() == GeneticState::Finish => {
trace!("Copying node {}", r.val.id());
if let Some(right_node) = r.val.as_ref() {
tree.val = GeneticNodeWrapper::from(right_node.clone(), tree.val.id());
}
}
(None, Some(r)) => Gemla::merge_completed_nodes(r, gemla_context.clone()).await?,
(_, _) => (),
}
}
@ -240,15 +299,18 @@ where
}
fn get_unprocessed_node(&self, tree: &SimulationTree<T>) -> Option<GeneticNodeWrapper<T>> {
// If the current node has been processed or exists in the thread list then we want to stop recursing. Checking if it exists in the thread list
// should be fine because we process the tree from bottom to top.
if tree.val.state() != GeneticState::Finish && !self.threads.contains_key(&tree.val.id()) {
match (&tree.left, &tree.right) {
// If the children are finished we can start processing the current node. The current node should be merged from the children already
// during join_threads.
(Some(l), Some(r))
if l.val.state() == GeneticState::Finish
&& r.val.state() == GeneticState::Finish =>
{
Some(tree.val.clone())
}
(Some(l), Some(r)) => self
.get_unprocessed_node(l)
.or_else(|| self.get_unprocessed_node(r)),
@ -278,25 +340,19 @@ where
}
}
fn increase_height(tree: Option<SimulationTree<T>>, amount: u64) -> Option<SimulationTree<T>> {
if amount == 0 {
tree
} else {
let left_branch_height =
tree.as_ref().map(|t| t.height() as u64).unwrap_or(0) + amount - 1;
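// e.g. an existing height-2 tree increased by amount = 2 is wrapped twice:
// the recursion bottoms out at amount = 0 and each level adds a new root
// whose right child is a single fresh leaf, yielding a final height of 4.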
Some(Box::new(Tree::new(
GeneticNodeWrapper::new(),
Gemla::increase_height(tree, amount - 1),
// The right branch height has to equal the left branch's total height
if left_branch_height > 0 {
Some(Box::new(btree!(GeneticNodeWrapper::new())))
} else {
None
},
@ -306,16 +362,19 @@ where
fn is_completed(tree: &SimulationTree<T>) -> bool {
// If the current node is finished, then by convention the children should all be finished as well
tree.val.state() == GeneticState::Finish
}
async fn process_node(
mut node: GeneticNodeWrapper<T>,
gemla_context: T::Context,
) -> Result<GeneticNodeWrapper<T>, Error> {
let node_state_time = Instant::now();
let node_state = node.state();
node.process_node(gemla_context.clone()).await?;
info!(
"{:?} completed in {:?} for {}",
node_state,
node_state_time.elapsed(),
@ -333,9 +392,13 @@ where
#[cfg(test)]
mod tests {
use crate::core::*;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::fs;
use std::path::PathBuf;
use tokio::runtime::Runtime;
use self::genetic_node::GeneticNodeContext;
struct CleanUp {
path: PathBuf,
@ -364,23 +427,43 @@ mod tests {
#[derive(Deserialize, Serialize, Clone, Debug, PartialEq)]
struct TestState {
pub score: f64,
pub max_generations: u64,
}
#[async_trait]
impl genetic_node::GeneticNode for TestState {
type Context = ();
async fn simulate(
&mut self,
context: GeneticNodeContext<Self::Context>,
) -> Result<bool, Error> {
self.score += 1.0;
Ok(context.generation < self.max_generations)
}
async fn mutate(
&mut self,
_context: GeneticNodeContext<Self::Context>,
) -> Result<(), Error> {
Ok(())
}
async fn initialize(
_context: GeneticNodeContext<Self::Context>,
) -> Result<Box<TestState>, Error> {
Ok(Box::new(TestState {
score: 0.0,
max_generations: 10,
}))
}
async fn merge(
left: &TestState,
right: &TestState,
_id: &Uuid,
_: Self::Context,
) -> Result<Box<TestState>, Error> {
Ok(Box::new(if left.score > right.score {
left.clone()
} else {
@ -389,66 +472,93 @@ mod tests {
}
}
#[tokio::test]
async fn test_new() -> Result<(), Error> {
let path = PathBuf::from("test_new_non_existing");
// Use `spawn_blocking` to run synchronous code that needs to call async code internally.
tokio::task::spawn_blocking(move || {
let rt = Runtime::new().unwrap(); // Create a new Tokio runtime for the async block.
CleanUp::new(&path).run(move |p| {
rt.block_on(async {
assert!(!path.exists());
// Testing initial creation
let mut config = GemlaConfig { overwrite: true };
let mut gemla = Gemla::<TestState>::new(&p, config, DataFormat::Json).await?;
// Now we can use `.await` within the spawned blocking task.
gemla.simulate(2).await?;
let data = gemla.data.readonly();
let data_lock = data.read().await;
assert_eq!(data_lock.0.as_ref().unwrap().height(), 2);
drop(data_lock);
drop(gemla);
assert!(path.exists());
// Testing overwriting data
let mut gemla = Gemla::<TestState>::new(&p, config, DataFormat::Json).await?;
gemla.simulate(2).await?;
let data = gemla.data.readonly();
let data_lock = data.read().await;
assert_eq!(data_lock.0.as_ref().unwrap().height(), 2);
drop(data_lock);
drop(gemla);
assert!(path.exists());
// Testing not-overwriting data
config.overwrite = false;
let mut gemla = Gemla::<TestState>::new(&p, config, DataFormat::Json).await?;
gemla.simulate(2).await?;
let data = gemla.data.readonly();
let data_lock = data.read().await;
let tree = data_lock.0.as_ref().unwrap();
assert_eq!(tree.height(), 4);
drop(data_lock);
drop(gemla);
assert!(path.exists());
Ok(())
})
})
})
.await
.unwrap()?; // Wait for the blocking task to complete, then handle the Result.
Ok(())
}
#[tokio::test]
async fn test_simulate() -> Result<(), Error> {
let path = PathBuf::from("test_simulate");
// Use `spawn_blocking` to run the synchronous closure that internally awaits async code.
tokio::task::spawn_blocking(move || {
let rt = Runtime::new().unwrap(); // Create a new Tokio runtime for the async block.
CleanUp::new(&path).run(move |p| {
rt.block_on(async {
// Testing initial creation
let config = GemlaConfig { overwrite: true };
let mut gemla = Gemla::<TestState>::new(&p, config, DataFormat::Json).await?;
// Now we can use `.await` within the spawned blocking task.
gemla.simulate(5).await?;
let data = gemla.data.readonly();
let data_lock = data.read().await;
let tree = data_lock.0.as_ref().unwrap();
assert_eq!(tree.height(), 5);
assert_eq!(tree.val.as_ref().unwrap().score, 50.0);
Ok(())
})
})
})
.await
.unwrap()?; // Wait for the blocking task to complete, then handle the Result.
Ok(())
}
}


@ -1,5 +1,4 @@
#[macro_use]
pub mod tree;
pub mod constants;
pub mod core;
pub mod error;


@ -36,7 +36,7 @@ use std::cmp::max;
/// t.right = Some(Box::new(btree!(3)));
/// assert_eq!(t.right.unwrap().val, 3);
/// ```
#[derive(Default, Serialize, Deserialize, PartialEq, Debug)]
pub struct Tree<T> {
pub val: T,
pub left: Option<Box<Tree<T>>>,

380
parameter_analysis.py Normal file

@ -0,0 +1,380 @@
# Re-importing necessary libraries
import json
import matplotlib.pyplot as plt
from collections import defaultdict
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.colors as mcolors
import matplotlib.cm as cm
import matplotlib.ticker as ticker
# Simplified JSON data for demonstration
with open('gemla/round4.json', 'r') as file:
simplified_json_data = json.load(file)
# Function to traverse the tree to find a node id
def traverse_right_nodes(node):
if node is None:
return []
right_node = node.get("right")
left_node = node.get("left")
if right_node is None and left_node is None:
return []
elif right_node and left_node:
return [right_node] + traverse_right_nodes(left_node)
return []
# Getting most recent right graph
right_nodes = traverse_right_nodes(simplified_json_data[0])
# Heatmaps
# Data structure to store mutation rates, generations, and scores
mutation_rate_data = defaultdict(lambda: defaultdict(list))
# Populate the dictionary with scores indexed by mutation rate and generation
for node in right_nodes:
node_val = node["val"]["node"]
if node_val:
scores = node_val["scores"]
minor_mutation_rate = node_val["minor_mutation_rate"]
generation = node_val["generation"]
# Ensure each score is associated with the correct generation
for gen_index, score_list in enumerate(scores):
for score in score_list.values():
mutation_rate_data[minor_mutation_rate][gen_index].append(score)
# Prepare data for heatmap
max_generation = max(max(gens.keys()) for gens in mutation_rate_data.values())
heatmap_data = np.full((len(mutation_rate_data), max_generation + 1), np.nan)
# Populate the heatmap data with average scores
mutation_rates = sorted(mutation_rate_data.keys())
for i, mutation_rate in enumerate(mutation_rates):
for generation in range(max_generation + 1):
scores = mutation_rate_data[mutation_rate][generation]
if scores: # Check if there are scores for this generation
heatmap_data[i, generation] = np.mean(scores)
# Creating a DataFrame for the heatmap
df_heatmap = pd.DataFrame(
data=heatmap_data,
index=mutation_rates,
columns=range(max_generation + 1)
)
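# The same aggregation pattern repeats for each hyperparameter below: bucket
# scores by (parameter value, generation), then average each bucket into a
# (parameter x generation) matrix. A minimal sketch of the shared pattern,
# with `param_data` standing in for any of the defaultdicts used here:
#
#   matrix = np.full((len(param_data), max_generation + 1), np.nan)
#   for i, key in enumerate(sorted(param_data)):
#       for gen, bucket in param_data[key].items():
#           matrix[i, gen] = np.mean(bucket)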
# Data structure to store major mutation rates, generations, and scores
major_mutation_rate_data = defaultdict(lambda: defaultdict(list))
# Populate the dictionary with scores indexed by major mutation rate and generation
# This is assuming the structure to retrieve major_mutation_rate is similar to minor_mutation_rate
for node in right_nodes:
node_val = node["val"]["node"]
if node_val:
scores = node_val["scores"]
major_mutation_rate = node_val["major_mutation_rate"]
generation = node_val["generation"]
for gen_index, score_list in enumerate(scores):
for score in score_list.values():
major_mutation_rate_data[major_mutation_rate][gen_index].append(score)
# Prepare the heatmap data for major_mutation_rate similar to minor_mutation_rate
major_heatmap_data = np.full((len(major_mutation_rate_data), max_generation + 1), np.nan)
major_mutation_rates = sorted(major_mutation_rate_data.keys())
for i, major_rate in enumerate(major_mutation_rates):
for generation in range(max_generation + 1):
scores = major_mutation_rate_data[major_rate][generation]
if scores: # Check if there are scores for this generation
major_heatmap_data[i, generation] = np.mean(scores)
# Creating a DataFrame for the major mutation rate heatmap
df_major_heatmap = pd.DataFrame(
data=major_heatmap_data,
index=major_mutation_rates,
columns=range(max_generation + 1)
)
# crossbreed_segments
# Data structure to store crossbreed segment counts, generations, and scores
crossbreed_segments_data = defaultdict(lambda: defaultdict(list))
# Populate the dictionary with scores indexed by crossbreed segments and generation
# This assumes the structure to retrieve crossbreed_segments is similar to minor_mutation_rate
for node in right_nodes:
node_val = node["val"]["node"]
if node_val:
scores = node_val["scores"]
crossbreed_segments = node_val["crossbreed_segments"]
generation = node_val["generation"]
for gen_index, score_list in enumerate(scores):
for score in score_list.values():
crossbreed_segments_data[crossbreed_segments][gen_index].append(score)
# Prepare the heatmap data for crossbreed_segments similar to minor_mutation_rate
crossbreed_heatmap_data = np.full((len(crossbreed_segments_data), max_generation + 1), np.nan)
crossbreed_segments = sorted(crossbreed_segments_data.keys())
for i, crossbreed_segment in enumerate(crossbreed_segments):
for generation in range(max_generation + 1):
scores = crossbreed_segments_data[crossbreed_segment][generation]
if scores: # Check if there are scores for this generation
crossbreed_heatmap_data[i, generation] = np.mean(scores)
# Creating a DataFrame for the crossbreed segments heatmap
df_crossbreed_heatmap = pd.DataFrame(
data=crossbreed_heatmap_data,
index=crossbreed_segments,
columns=range(max_generation + 1)
)
# mutation_weight_range
# Data structure to store mutation weight ranges, generations, and scores
mutation_weight_range_data = defaultdict(lambda: defaultdict(list))
# Populate the dictionary with scores indexed by mutation weight range and generation
# This assumes the structure to retrieve mutation_weight_range is similar to minor_mutation_rate
for node in right_nodes:
node_val = node["val"]["node"]
if node_val:
scores = node_val["scores"]
mutation_weight_range = node_val["mutation_weight_range"]
positive_extent = mutation_weight_range["end"]
negative_extent = -mutation_weight_range["start"]
mutation_weight_range = (positive_extent + negative_extent) / 2
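# e.g. a range of start = -0.4, end = 0.6 collapses to (0.6 + 0.4) / 2 = 0.5,
# the average extent of the range around zero.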
generation = node_val["generation"]
for gen_index, score_list in enumerate(scores):
for score in score_list.values():
mutation_weight_range_data[mutation_weight_range][gen_index].append(score)
# Prepare the heatmap data for mutation_weight_range similar to minor_mutation_rate
mutation_weight_range_heatmap_data = np.full((len(mutation_weight_range_data), max_generation + 1), np.nan)
mutation_weight_ranges = sorted(mutation_weight_range_data.keys())
for i, mutation_weight_range in enumerate(mutation_weight_ranges):
for generation in range(max_generation + 1):
scores = mutation_weight_range_data[mutation_weight_range][generation]
if scores: # Check if there are scores for this generation
mutation_weight_range_heatmap_data[i, generation] = np.mean(scores)
# Creating a DataFrame for the mutation weight range heatmap
df_mutation_weight_range_heatmap = pd.DataFrame(
data=mutation_weight_range_heatmap_data,
index=mutation_weight_ranges,
columns=range(max_generation + 1)
)
# weight_initialization_range
# Data structure to store weight initialization ranges, generations, and scores
weight_initialization_range_data = defaultdict(lambda: defaultdict(list))
# Populate the dictionary with scores indexed by weight initialization range and generation
# This assumes the structure to retrieve weight_initialization_range is similar to minor_mutation_rate
for node in right_nodes:
node_val = node["val"]["node"]
if node_val:
scores = node_val["scores"]
weight_initialization_range = node_val["weight_initialization_range"]
positive_extent = weight_initialization_range["end"]
negative_extent = -weight_initialization_range["start"]
weight_initialization_range = (positive_extent + negative_extent) / 2
generation = node_val["generation"]
for gen_index, score_list in enumerate(scores):
for score in score_list.values():
weight_initialization_range_data[weight_initialization_range][gen_index].append(score)
# Prepare the heatmap data for weight_initialization_range similar to minor_mutation_rate
weight_initialization_range_heatmap_data = np.full((len(weight_initialization_range_data), max_generation + 1), np.nan)
weight_initialization_ranges = sorted(weight_initialization_range_data.keys())
for i, weight_initialization_range in enumerate(weight_initialization_ranges):
for generation in range(max_generation + 1):
scores = weight_initialization_range_data[weight_initialization_range][generation]
if scores: # Check if there are scores for this generation
weight_initialization_range_heatmap_data[i, generation] = np.mean(scores)
# Creating a DataFrame for the weight initialization range heatmap
df_weight_initialization_range_heatmap = pd.DataFrame(
data=weight_initialization_range_heatmap_data,
index=weight_initialization_ranges,
columns=range(max_generation + 1)
)
# weight_initialization_range_skew
# Data structure to store weight initialization range skews, generations, and scores
weight_initialization_range_skew_data = defaultdict(lambda: defaultdict(list))
# Populate the dictionary with scores indexed by weight initialization range skew and generation
# This assumes the structure to retrieve weight_initialization_range is similar to minor_mutation_rate
for node in right_nodes:
node_val = node["val"]["node"]
if node_val:
scores = node_val["scores"]
weight_initialization_range = node_val["weight_initialization_range"]
positive_extent = weight_initialization_range["end"]
negative_extent = -weight_initialization_range["start"]
weight_initialization_range_skew = (positive_extent - negative_extent) / 2
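# e.g. start = -0.4, end = 0.6 gives a skew of (0.6 - 0.4) / 2 = 0.1, the
# midpoint of the range, so a positive skew means the range leans positive.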
generation = node_val["generation"]
for gen_index, score_list in enumerate(scores):
for score in score_list.values():
weight_initialization_range_skew_data[weight_initialization_range_skew][gen_index].append(score)
# Prepare the heatmap data for weight_initialization_range_skew similar to minor_mutation_rate
weight_initialization_range_skew_heatmap_data = np.full((len(weight_initialization_range_skew_data), max_generation + 1), np.nan)
weight_initialization_range_skews = sorted(weight_initialization_range_skew_data.keys())
for i, weight_initialization_range_skew in enumerate(weight_initialization_range_skews):
for generation in range(max_generation + 1):
scores = weight_initialization_range_skew_data[weight_initialization_range_skew][generation]
if scores: # Check if there are scores for this generation
weight_initialization_range_skew_heatmap_data[i, generation] = np.mean(scores)
# Creating a DataFrame for the weight initialization range skew heatmap
df_weight_initialization_range_skew_heatmap = pd.DataFrame(
data=weight_initialization_range_skew_heatmap_data,
index=weight_initialization_range_skews,
columns=range(max_generation + 1)
)
# Analyze number of neurons correlation to score
# We can get the number of neurons via node_val["nn_shapes"] which contains an array of maps
# Each map has a key for the individual id and a value which is an array of integers representing the number of neurons in each layer
# We can use the individual id to get the score from the scores array
# We then generate a density map of the number of neurons vs the score
neuron_number_score_data = defaultdict(lambda: defaultdict(list))
for node in right_nodes:
node_val = node["val"]["node"]
if node_val:
scores = node_val["scores"]
nn_shapes = node_val["nn_shapes"]
# Both scores and nn_shapes are arrays where score is 1 less in length than nn_shapes (each index corresponds to a generation)
for gen_index, score in enumerate(scores):
for individual_id, nn_shape in nn_shapes[gen_index].items():
neuron_number = sum(nn_shape)
# check if score has a value for the individual id
if individual_id not in score:
continue
neuron_number_score_data[neuron_number][gen_index].append(score[individual_id])
# prepare the density map data
neuron_number_score_heatmap_data = np.full((len(neuron_number_score_data), max_generation + 1), np.nan)
neuron_numbers = sorted(neuron_number_score_data.keys())
for i, neuron_number in enumerate(neuron_numbers):
for generation in range(max_generation + 1):
scores = neuron_number_score_data[neuron_number][generation]
if scores: # Check if there are scores for this generation
neuron_number_score_heatmap_data[i, generation] = np.mean(scores)
# Creating a DataFrame for the neuron count heatmap
df_neuron_number_score_heatmap = pd.DataFrame(
data=neuron_number_score_heatmap_data,
index=neuron_numbers,
columns=range(max_generation + 1)
)
# Analyze number of layers correlation to score
nn_layers_score_data = defaultdict(lambda: defaultdict(list))
for node in right_nodes:
node_val = node["val"]["node"]
if node_val:
scores = node_val["scores"]
nn_shapes = node_val["nn_shapes"]
# Both scores and nn_shapes are arrays where score is 1 less in length than nn_shapes (each index corresponds to a generation)
for gen_index, score in enumerate(scores):
for individual_id, nn_shape in nn_shapes[gen_index].items():
layer_number = len(nn_shape)
# check if score has a value for the individual id
if individual_id not in score:
continue
nn_layers_score_data[layer_number][gen_index].append(score[individual_id])
# prepare the density map data
nn_layers_score_heatmap_data = np.full((len(nn_layers_score_data), max_generation + 1), np.nan)
nn_layers = sorted(nn_layers_score_data.keys())
for i, nn_layer in enumerate(nn_layers):
for generation in range(max_generation + 1):
scores = nn_layers_score_data[nn_layer][generation]
if scores: # Check if there are scores for this generation
nn_layers_score_heatmap_data[i, generation] = np.mean(scores)
# Creating a DataFrame for the layer count heatmap
df_nn_layers_score_heatmap = pd.DataFrame(
data=nn_layers_score_heatmap_data,
index=nn_layers,
columns=range(max_generation + 1)
)
# print("Format: ", custom_formatter(0.123498761234, 0))
# Creating subplots
fig, axs = plt.subplots(2, 2, figsize=(20, 14)) # Creates a 2x2 grid of subplots
# Plotting the minor mutation rate heatmap
sns.heatmap(df_heatmap.T, cmap='viridis', fmt=".4g", cbar_kws={'label': 'Mean Score'}, ax=axs[0, 0])
# axs[0, 0].set_title('Minor Mutation Rate')
axs[0, 0].set_xlabel('Minor Mutation Rate')
axs[0, 0].set_ylabel('Generation')
axs[0, 0].invert_yaxis()
# Plotting the major mutation rate heatmap
sns.heatmap(df_major_heatmap.T, cmap='viridis', fmt=".4g", cbar_kws={'label': 'Mean Score'}, ax=axs[0, 1])
# axs[0, 1].set_title('Major Mutation Rate')
axs[0, 1].set_xlabel('Major Mutation Rate')
axs[0, 1].invert_yaxis()
# Plotting the crossbreed_segments heatmap
sns.heatmap(df_crossbreed_heatmap.T, cmap='viridis', fmt=".4g", cbar_kws={'label': 'Mean Score'}, ax=axs[1, 0])
# axs[1, 0].set_title('Crossbreed Segments')
axs[1, 0].set_xlabel('Crossbreed Segments')
axs[1, 0].set_ylabel('Generation')
axs[1, 0].invert_yaxis()
# Plotting the mutation_weight_range heatmap
sns.heatmap(df_mutation_weight_range_heatmap.T, cmap='viridis', fmt=".4g", cbar_kws={'label': 'Mean Score'}, ax=axs[1, 1])
# axs[1, 1].set_title('Mutation Weight Range')
axs[1, 1].set_xlabel('Mutation Weight Range')
axs[1, 1].invert_yaxis()
fig3, axs3 = plt.subplots(1, 2, figsize=(20, 14)) # Creates a 1x2 grid of subplots
# Plotting the weight_initialization_range heatmap
sns.heatmap(df_weight_initialization_range_heatmap.T, cmap='viridis', fmt=".4g", cbar_kws={'label': 'Mean Score'}, ax=axs3[0])
# axs[2, 0].set_title('Weight Initialization Range')
axs3[0].set_xlabel('Weight Initialization Range')
axs3[0].set_ylabel('Generation')
axs3[0].invert_yaxis()
# Plotting the weight_initialization_range_skew heatmap
sns.heatmap(df_weight_initialization_range_skew_heatmap.T, cmap='viridis', fmt=".4g", cbar_kws={'label': 'Mean Score'}, ax=axs3[1])
# axs[2, 1].set_title('Weight Initialization Range Skew')
axs3[1].set_xlabel('Weight Initialization Range Skew')
axs3[1].set_ylabel('Generation')
axs3[1].invert_yaxis()
# Creating a new window for the scatter plots
fig2, axs2 = plt.subplots(2, 1, figsize=(20, 14)) # Creates a 2x1 grid of subplots
# Plotting the neuron number vs score heatmap
sns.heatmap(df_neuron_number_score_heatmap.T, cmap='viridis', fmt=".4g", cbar_kws={'label': 'Mean Score'}, ax=axs2[1])
# axs[3, 1].set_title('Neuron Number vs. Score')
axs2[1].set_xlabel('Neuron Number')
axs2[1].set_ylabel('Generation')
axs2[1].invert_yaxis()
# Plotting the number of layers vs score heatmap
sns.heatmap(df_nn_layers_score_heatmap.T, cmap='viridis', fmt=".4g", cbar_kws={'label': 'Mean Score'}, ax=axs2[0])
# axs[3, 1].set_title('Number of Layers vs. Score')
axs2[0].set_xlabel('Number of Layers')
axs2[0].set_ylabel('Generation')
axs2[0].invert_yaxis()
# Display the plot
plt.tight_layout() # Adjusts the subplots to fit into the figure area.
plt.show()


@ -1,12 +0,0 @@
;; Define a type that contains a population size and a population cutoff
(defclass simulation-node () ((population-size :initarg :population-size :accessor population-size)
(population-cutoff :initarg :population-cutoff :accessor population-cutoff)
(population :initform () :accessor population)))
;; Define a method that initializes population-size number of children in a population each with a random value
(defmethod initialize-instance :after ((node simulation-node) &key)
(setf (population node) (make-list (population-size node) :initial-element (random 100))))
(let ((node (make-instance 'simulation-node :population-size 100 :population-cutoff 10)))
(print (population-size node))
(population node))

118
visualize_networks.py Normal file

@ -0,0 +1,118 @@
import matplotlib.pyplot as plt
import networkx as nx
import subprocess
import tkinter as tk
from tkinter import filedialog
def select_file():
root = tk.Tk()
root.withdraw() # Hide the main window
file_path = filedialog.askopenfilename(
initialdir="/", # Set the initial directory to search for files
title="Select file",
filetypes=(("Net files", "*.net"), ("All files", "*.*"))
)
return file_path
def get_fann_data(network_file):
# Adjust the path to the Rust executable as needed
result = subprocess.run(['./extract_fann_data/target/debug/extract_fann_data.exe', network_file], capture_output=True, text=True)
if result.returncode != 0:
print("Error:", result.stderr)
return None, None
layer_sizes = []
connections = []
parsing_connections = False
for line in result.stdout.splitlines():
if line.startswith("Layers:"):
continue
elif line.startswith("Connections:"):
parsing_connections = True
continue
if parsing_connections:
from_neuron, to_neuron, weight = map(float, line.split())
connections.append((int(from_neuron), int(to_neuron), weight))
else:
layer_size, bias_count = map(int, line.split())
layer_sizes.append((layer_size, bias_count))
return layer_sizes, connections
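# The parser above assumes the extract_fann_data helper prints layer lines as
# "<layer_size> <bias_count>" pairs under a "Layers:" header, followed by
# "<from_neuron> <to_neuron> <weight>" triples under a "Connections:" header;
# any deviation from that format will raise in the map(...) calls.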
def visualize_fann_network(network_file):
# Get network data
layer_sizes, connections = get_fann_data(network_file)
if layer_sizes is None or connections is None:
return # Error handling in get_fann_data should provide error output
# Create a directed graph
G = nx.DiGraph()
# Positions dictionary to hold the position of each neuron
pos = {}
node_count = 0
x_spacing = 1.0
y_spacing = 1.0
# Calculate the maximum layer size for proper spacing
max_layer_size = max(size for size, bias in layer_sizes)
# Build nodes and position them layer by layer from left to right
for layer_index, (layer_size, bias_count) in enumerate(layer_sizes):
y_positions = list(range(-layer_size-bias_count+1, 1, 1)) # Center-align vertically
y_positions = [y * (max_layer_size / (layer_size + bias_count)) * y_spacing for y in y_positions] # Adjust spacing
for neuron_index in range(layer_size + bias_count): # Include bias neurons
node_label = f"L{layer_index}N{neuron_index}"
G.add_node(node_count, label=node_label)
pos[node_count] = (layer_index * x_spacing, y_positions[neuron_index % len(y_positions)])
node_count += 1
# Add connections to the graph
for from_neuron, to_neuron, weight in connections:
G.add_edge(from_neuron, to_neuron, weight=weight)
max_weight = max(abs(weight) for _, _, weight in connections)
print(f"Max weight: {max_weight}")
# Draw nodes
nx.draw_networkx_nodes(G, pos, node_color='skyblue', node_size=200)
nx.draw_networkx_labels(G, pos, font_size=7)
# Custom function for edge properties
def adjust_properties(weight):
# if weight > 0:
# print("Weight:", weight)
color = 'green' if weight > 0 else 'red'
alpha = min((abs(weight) / max_weight) ** 3, 1)
# print(f"Color: {color}, Alpha: {alpha}")
return color, alpha
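# Cubing the normalized magnitude de-emphasizes weak connections: a weight at
# half the maximum is drawn at alpha = 0.5 ** 3 = 0.125, so only relatively
# strong weights stand out.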
# Draw edges with custom properties
for u, v, d in G.edges(data=True):
color, alpha = adjust_properties(d['weight'])
nx.draw_networkx_edges(G, pos, edgelist=[(u, v)], edge_color=color, alpha=alpha, width=1.5, arrows=False)
# Show plot
plt.title('FANN Network Visualization')
plt.axis('off') # Turn off the axis
plt.show()
# Path to the FANN network file
fann_path = 'F:\\\\vandomej\\Projects\\dootcamp-AI-Simulation\\Simulations\\fighter_nn_4f2be613-ab26-4384-9a65-450e043984ea\\6\\4f2be613-ab26-4384-9a65-450e043984ea_fighter_nn_0.net'
# fann_path = "F:\\\\vandomej\\Projects\\dootcamp-AI-Simulation\\Simulations\\fighter_nn_fc294503-7b2a-40f8-be59-ccc486eb3f79\\0\\fc294503-7b2a-40f8-be59-ccc486eb3f79_fighter_nn_0.net"
# fann_path = 'F:\\\\vandomej\\Projects\\dootcamp-AI-Simulation\\Simulations\\fighter_nn_99c30a7f-40ab-4faf-b16a-b44703fdb6cd\\0\\99c30a7f-40ab-4faf-b16a-b44703fdb6cd_fighter_nn_0.net'
# Has a 4 layer network
# # Generation 1
# fann_path = "F:\\\\vandomej\\Projects\\dootcamp-AI-Simulation\\Simulations\\fighter_nn_16dfa1b4-03c7-45a6-84b4-22fe3c8e2d98\\1\\16dfa1b4-03c7-45a6-84b4-22fe3c8e2d98_fighter_nn_0.net"
# # Generation 5
# fann_path = "F:\\\\vandomej\\Projects\\dootcamp-AI-Simulation\\Simulations\\fighter_nn_16dfa1b4-03c7-45a6-84b4-22fe3c8e2d98\\5\\16dfa1b4-03c7-45a6-84b4-22fe3c8e2d98_fighter_nn_0.net"
# # Generation 10
# fann_path = "F:\\\\vandomej\\Projects\\dootcamp-AI-Simulation\\Simulations\\fighter_nn_16dfa1b4-03c7-45a6-84b4-22fe3c8e2d98\\10\\16dfa1b4-03c7-45a6-84b4-22fe3c8e2d98_fighter_nn_0.net"
# # Generation 20
# fann_path = "F:\\\\vandomej\\Projects\\dootcamp-AI-Simulation\\Simulations\\fighter_nn_16dfa1b4-03c7-45a6-84b4-22fe3c8e2d98\\20\\16dfa1b4-03c7-45a6-84b4-22fe3c8e2d98_fighter_nn_0.net"
# # Generation 32
# fann_path = "F:\\\\vandomej\\Projects\\dootcamp-AI-Simulation\\Simulations\\fighter_nn_16dfa1b4-03c7-45a6-84b4-22fe3c8e2d98\\32\\16dfa1b4-03c7-45a6-84b4-22fe3c8e2d98_fighter_nn_0.net"
fann_path = select_file()
visualize_fann_network(fann_path)


@ -0,0 +1,104 @@
# Re-importing necessary libraries
import json
import matplotlib.pyplot as plt
import networkx as nx
import random
def hierarchy_pos(G, root=None, width=1., vert_gap=0.2, vert_loc=0, xcenter=0.5):
if not nx.is_tree(G):
raise TypeError('cannot use hierarchy_pos on a graph that is not a tree')
if root is None:
if isinstance(G, nx.DiGraph):
root = next(iter(nx.topological_sort(G)))
else:
root = random.choice(list(G.nodes))
def _hierarchy_pos(G, root, width=2., vert_gap=0.2, vert_loc=0, xcenter=0.5, pos=None, parent=None):
if pos is None:
pos = {root: (xcenter, vert_loc)}
else:
pos[root] = (xcenter, vert_loc)
children = list(G.successors(root)) # Use successors to get children for DiGraph
if not isinstance(G, nx.DiGraph):
if parent is not None:
children.remove(parent)
if len(children) != 0:
dx = width / len(children)
nextx = xcenter - width / 2 - dx / 2
for child in children:
nextx += dx
pos = _hierarchy_pos(G, child, width=dx*2.0, vert_gap=vert_gap,
vert_loc=vert_loc - vert_gap, xcenter=nextx,
pos=pos, parent=root)
return pos
return _hierarchy_pos(G, root, width, vert_gap, vert_loc, xcenter)
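# Each child is allotted a horizontal band of dx * 2.0, twice its even share
# of the parent's width, so sibling subtrees may overlap their neighbours'
# bands slightly; vert_gap steps the y coordinate down one level per depth.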
# Simplified JSON data for demonstration
with open('gemla/round4.json', 'r') as file:
simplified_json_data = json.load(file)
# Function to traverse the tree and create a graph
def traverse(node, graph, parent=None):
if node is None:
return
node_id = node["val"]["id"]
if "node" in node["val"] and node["val"]["node"]:
scores = node["val"]["node"]["scores"]
generations = node["val"]["node"]["generation"]
population_size = node["val"]["node"]["population_size"]
# Prepare to track the highest score across all generations and the corresponding individual
overall_max_score = float('-inf')
overall_max_score_individual = None
overall_max_score_gen = None
for gen, gen_scores in enumerate(scores):
if gen_scores: # Ensure the dictionary is not empty
# Find the max score and the individual for this generation
max_score_for_gen = max(gen_scores.values())
individual_with_max_score_for_gen = max(gen_scores, key=gen_scores.get)
# if max_score_for_gen > overall_max_score:
overall_max_score = max_score_for_gen
overall_max_score_individual = individual_with_max_score_for_gen
overall_max_score_gen = gen
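# NOTE: with the comparison above commented out, these values track the max of
# the last non-empty generation rather than a running overall max; restore the
# `if` check to get a true maximum across all generations.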
# print debug statement
# print(f"Node {node_id}: Max score: {overall_max_score:.6f} (Individual {overall_max_score_individual} in Gen {overall_max_score_gen})")
# print(f"Left: {node.get('left')}, Right: {node.get('right')}")
label = f"{node_id}\nGenerations: {generations}, Population: {population_size}\nMax score: {overall_max_score:.6f} (Individual {overall_max_score_individual} in Gen {overall_max_score_gen + 1 if overall_max_score_gen is not None else 'N/A'})"
else:
label = node_id
graph.add_node(node_id, label=label)
if parent:
graph.add_edge(parent, node_id)
traverse(node.get("left"), graph, parent=node_id)
traverse(node.get("right"), graph, parent=node_id)
# Create a directed graph
G = nx.DiGraph()
# Populate the graph
traverse(simplified_json_data[0], G)
# Find the root node (a node with no incoming edges)
root_candidates = [node for node, indeg in G.in_degree() if indeg == 0]
if root_candidates:
root_node = root_candidates[0] # Assuming there's only one root candidate
else:
root_node = None # This should ideally never happen in a properly structured tree
# Use the determined root node for hierarchy_pos
if root_node is not None:
pos = hierarchy_pos(G, root=root_node)
labels = nx.get_node_attributes(G, 'label')
nx.draw(G, pos, labels=labels, with_labels=True, arrows=True)
plt.show()
else:
print("No root node found. Cannot draw the tree.")