Compare commits
27 commits
e8d373d4f9
...
6725ab3feb
Author | SHA1 | Date | |
---|---|---|---|
6725ab3feb | |||
![]() |
7156d6d733 | ||
![]() |
4efba94ff4 | ||
![]() |
e2be40c318 | ||
![]() |
822df77f62 | ||
![]() |
98803b3700 | ||
![]() |
a11def630a | ||
![]() |
7a1f82ac63 | ||
![]() |
b56e37d411 | ||
![]() |
05c7dcbe11 | ||
![]() |
0ccd824ee6 | ||
![]() |
95699bd47e | ||
![]() |
d473970325 | ||
![]() |
be02823c4c | ||
![]() |
5670649227 | ||
![]() |
1301d457a9 | ||
![]() |
ac71b28c7c | ||
![]() |
97086fdbe0 | ||
![]() |
c9b746e59d | ||
![]() |
ca3989421d | ||
![]() |
69b026593e | ||
![]() |
7ffd48f186 | ||
![]() |
774a0df5d7 | ||
![]() |
d21d0fcd3a | ||
![]() |
c44c389fbe | ||
![]() |
371e78fe4c | ||
![]() |
f0fdaa7af9 |
32 changed files with 5278 additions and 477 deletions
6
.gitignore
vendored
6
.gitignore
vendored
|
@ -13,4 +13,8 @@ settings.json
|
|||
|
||||
.DS_Store
|
||||
|
||||
.vscode/alive
|
||||
.vscode/alive
|
||||
|
||||
# Added by cargo
|
||||
|
||||
/target
|
||||
|
|
50
.vscode/launch.json
vendored
50
.vscode/launch.json
vendored
|
@ -10,7 +10,55 @@
|
|||
"name": "Debug",
|
||||
"program": "${workspaceFolder}/gemla/target/debug/gemla.exe",
|
||||
"args": ["./gemla/temp/"],
|
||||
"cwd": "${workspaceFolder}"
|
||||
"cwd": "${workspaceFolder}/gemla"
|
||||
},
|
||||
{
|
||||
"type": "lldb",
|
||||
"request": "launch",
|
||||
"name": "Debug Rust Tests",
|
||||
"cargo": {
|
||||
"args": [
|
||||
"test",
|
||||
"--manifest-path", "${workspaceFolder}/gemla/Cargo.toml",
|
||||
"--no-run", // Compiles the tests without running them
|
||||
"--package=gemla", // Specify your package name if necessary
|
||||
"--bin=bin"
|
||||
],
|
||||
"filter": { }
|
||||
},
|
||||
"args": [],
|
||||
},
|
||||
{
|
||||
"type": "lldb",
|
||||
"request": "launch",
|
||||
"name": "Debug gemla Lib Tests",
|
||||
"cargo": {
|
||||
"args": [
|
||||
"test",
|
||||
"--manifest-path", "${workspaceFolder}/gemla/Cargo.toml",
|
||||
"--no-run", // Compiles the tests without running them
|
||||
"--package=gemla", // Specify your package name if necessary
|
||||
"--lib"
|
||||
],
|
||||
"filter": { }
|
||||
},
|
||||
"args": [],
|
||||
},
|
||||
{
|
||||
"type": "lldb",
|
||||
"request": "launch",
|
||||
"name": "Debug Rust FileLinked Tests",
|
||||
"cargo": {
|
||||
"args": [
|
||||
"test",
|
||||
"--manifest-path", "${workspaceFolder}/file_linked/Cargo.toml",
|
||||
"--no-run", // Compiles the tests without running them
|
||||
"--package=file_linked", // Specify your package name if necessary
|
||||
"--lib"
|
||||
],
|
||||
"filter": { }
|
||||
},
|
||||
"args": [],
|
||||
}
|
||||
]
|
||||
}
|
171
analyze_data.py
Normal file
171
analyze_data.py
Normal file
|
@ -0,0 +1,171 @@
|
|||
# Re-importing necessary libraries
|
||||
import json
|
||||
import matplotlib.pyplot as plt
|
||||
from collections import defaultdict
|
||||
import numpy as np
|
||||
|
||||
# Simplified JSON data for demonstration
|
||||
with open('gemla/round4.json', 'r') as file:
|
||||
simplified_json_data = json.load(file)
|
||||
|
||||
target_node_id = '523f8250-3101-4586-90a1-127ffa6d73d9'
|
||||
|
||||
# Function to traverse the tree to find a node id
|
||||
def traverse_left_nodes(node):
|
||||
if node is None:
|
||||
return []
|
||||
|
||||
left_node = node.get("left")
|
||||
if left_node is None:
|
||||
return [node]
|
||||
|
||||
return [node] + traverse_left_nodes(left_node)
|
||||
|
||||
# Function to traverse the tree to find a node id
|
||||
def traverse_right_nodes(node):
|
||||
if node is None:
|
||||
return []
|
||||
|
||||
right_node = node.get("right")
|
||||
left_node = node.get("left")
|
||||
|
||||
if right_node is None and left_node is None:
|
||||
return []
|
||||
elif right_node and left_node:
|
||||
return [right_node] + traverse_right_nodes(left_node)
|
||||
|
||||
return []
|
||||
|
||||
|
||||
# Getting the left graph
|
||||
left_nodes = traverse_left_nodes(simplified_json_data[0])
|
||||
left_nodes.reverse()
|
||||
# print(node)
|
||||
# Print properties available on the first node
|
||||
node = left_nodes[0]
|
||||
# print(node["val"].keys())
|
||||
|
||||
scores = []
|
||||
for node in left_nodes:
|
||||
# print(node)
|
||||
# print(f'Node ID: {node["val"]["id"]}')
|
||||
# print(f'Node scores length: {len(node["val"]["node"]["scores"])}')
|
||||
if node["val"]["node"]:
|
||||
node_scores = node["val"]["node"]["scores"]
|
||||
if node_scores:
|
||||
for score in node_scores:
|
||||
scores.append(score)
|
||||
|
||||
# print(scores)
|
||||
|
||||
scores_values = [list(score_set.values()) for score_set in scores]
|
||||
|
||||
# Set up the figure for plotting on the same graph
|
||||
fig, ax = plt.subplots(figsize=(10, 6))
|
||||
|
||||
# Generate a boxplot for each set of scores on the same graph
|
||||
boxplots = ax.boxplot(scores_values, vert=False, patch_artist=True, labels=[f'Set {i+1}' for i in range(len(scores_values))])
|
||||
|
||||
# Set figure name to node id
|
||||
ax.set_xscale('symlog', linthresh=1.0)
|
||||
|
||||
# Labeling
|
||||
ax.set_xlabel(f'Scores - Main Line')
|
||||
ax.set_ylabel('Score Sets')
|
||||
ax.yaxis.grid(True) # Add horizontal grid lines for clarity
|
||||
|
||||
# Set y-axis labels to be visible
|
||||
ax.set_yticklabels([f'Set {i+1}' for i in range(len(scores_values))])
|
||||
|
||||
# Getting most recent right graph
|
||||
right_nodes = traverse_right_nodes(simplified_json_data[0])
|
||||
if len(right_nodes) != 0:
|
||||
target_node_id = None
|
||||
target_node = None
|
||||
if target_node_id:
|
||||
for node in right_nodes:
|
||||
if node["val"]["id"] == target_node_id:
|
||||
target_node = node
|
||||
break
|
||||
else:
|
||||
target_node = right_nodes[0]
|
||||
scores = target_node["val"]["node"]["scores"]
|
||||
|
||||
scores_values = [list(score_set.values()) for score_set in scores]
|
||||
|
||||
# Set up the figure for plotting on the same graph
|
||||
fig, ax = plt.subplots(figsize=(10, 6))
|
||||
|
||||
# Generate a boxplot for each set of scores on the same graph
|
||||
boxplots = ax.boxplot(scores_values, vert=False, patch_artist=True, labels=[f'Set {i+1}' for i in range(len(scores_values))])
|
||||
|
||||
ax.set_xscale('symlog', linthresh=1.0)
|
||||
|
||||
# Labeling
|
||||
ax.set_xlabel(f'Scores: {target_node['val']['id']}')
|
||||
ax.set_ylabel('Score Sets')
|
||||
ax.yaxis.grid(True) # Add horizontal grid lines for clarity
|
||||
|
||||
# Set y-axis labels to be visible
|
||||
ax.set_yticklabels([f'Set {i+1}' for i in range(len(scores_values))])
|
||||
|
||||
# Find the highest scoring sets combining all scores and generations
|
||||
scores = []
|
||||
for node in left_nodes:
|
||||
if node["val"]["node"]:
|
||||
node_scores = node["val"]["node"]["scores"]
|
||||
translated_node_scores = []
|
||||
if node_scores:
|
||||
for i in range(len(node_scores)):
|
||||
for (individual, score) in node_scores[i].items():
|
||||
translated_node_scores.append((node["val"]["id"], i, score))
|
||||
|
||||
scores.append(translated_node_scores)
|
||||
|
||||
# Add scores from the right nodes
|
||||
if len(right_nodes) != 0:
|
||||
for node in right_nodes:
|
||||
if node["val"]["node"]:
|
||||
node_scores = node["val"]["node"]["scores"]
|
||||
translated_node_scores = []
|
||||
if node_scores:
|
||||
for i in range(len(node_scores)):
|
||||
for (individual, score) in node_scores[i].items():
|
||||
translated_node_scores.append((node["val"]["id"], i, score))
|
||||
scores.append(translated_node_scores)
|
||||
|
||||
# Organize scores by individual and then by generation
|
||||
individual_generation_scores = defaultdict(lambda: defaultdict(list))
|
||||
for sublist in scores:
|
||||
for id, generation, score in sublist:
|
||||
individual_generation_scores[id][generation].append(score)
|
||||
|
||||
# Calculate Q3 for each individual's generation
|
||||
individual_generation_q3 = {}
|
||||
for id, generations in individual_generation_scores.items():
|
||||
for gen, scores in generations.items():
|
||||
individual_generation_q3[(id, gen)] = np.percentile(scores, 75)
|
||||
|
||||
# Sort by Q3 value, highest first, and select the top 20
|
||||
top_20_individual_generations = sorted(individual_generation_q3, key=individual_generation_q3.get, reverse=True)[:40]
|
||||
|
||||
# Prepare scores for the top 20 for plotting
|
||||
top_20_scores = [individual_generation_scores[id][gen] for id, gen in top_20_individual_generations]
|
||||
|
||||
# Adjust labels for clarity, indicating both the individual ID and generation
|
||||
labels = [f'{id[:8]}... Gen {gen}' for id, gen in top_20_individual_generations]
|
||||
|
||||
# Generate box and whisker plots for the top 20 individual generations
|
||||
fig, ax = plt.subplots(figsize=(12, 10))
|
||||
ax.boxplot(top_20_scores, vert=False, patch_artist=True, labels=labels)
|
||||
|
||||
ax.set_xscale('symlog', linthresh=1.0)
|
||||
|
||||
ax.set_xlabel('Scores')
|
||||
ax.set_ylabel('Individual Generation')
|
||||
ax.set_title('Top 20 Individual Generations by Q3 Value')
|
||||
ax.yaxis.grid(True) # Add horizontal grid lines for clarity
|
||||
|
||||
# Display the plot
|
||||
plt.show()
|
||||
|
1
carp_spike/.gitignore
vendored
1
carp_spike/.gitignore
vendored
|
@ -1 +0,0 @@
|
|||
out/
|
|
@ -1,10 +0,0 @@
|
|||
(use Random)
|
||||
(Project.config "title" "gemla")
|
||||
|
||||
(deftype SimulationNode [population-size Int, population-cutoff Int])
|
||||
|
||||
;; (let [test (SimulationNode.init 10 3)]
|
||||
;; (do
|
||||
;; (SimulationNode.set-population-size test 20)
|
||||
;; (SimulationNode.population-size &test)
|
||||
;; ))
|
9
extract_fann_data/Cargo.toml
Normal file
9
extract_fann_data/Cargo.toml
Normal file
|
@ -0,0 +1,9 @@
|
|||
[package]
|
||||
name = "extract_fann_data"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
fann = "0.1.8"
|
11
extract_fann_data/build.rs
Normal file
11
extract_fann_data/build.rs
Normal file
|
@ -0,0 +1,11 @@
|
|||
fn main() {
|
||||
// Replace this with the path to the directory containing `fann.lib`
|
||||
let lib_dir = "F://vandomej/Downloads/vcpkg/packages/fann_x64-windows/lib";
|
||||
|
||||
println!("cargo:rustc-link-search=native={}", lib_dir);
|
||||
println!("cargo:rustc-link-lib=static=fann");
|
||||
// Use `dylib=fann` instead of `static=fann` if you're linking dynamically
|
||||
|
||||
// If there are any additional directories where the compiler can find header files, you can specify them like this:
|
||||
// println!("cargo:include={}", path_to_include_directory);
|
||||
}
|
38
extract_fann_data/src/main.rs
Normal file
38
extract_fann_data/src/main.rs
Normal file
|
@ -0,0 +1,38 @@
|
|||
extern crate fann;
|
||||
|
||||
use fann::Fann;
|
||||
use std::os::raw::c_uint;
|
||||
|
||||
fn main() {
|
||||
let args: Vec<String> = std::env::args().collect();
|
||||
if args.len() < 2 {
|
||||
eprintln!("Usage: {} <network_file>", args[0]);
|
||||
std::process::exit(1);
|
||||
}
|
||||
|
||||
let network_file = &args[1];
|
||||
match Fann::from_file(network_file) {
|
||||
Ok(ann) => {
|
||||
// Output layer sizes
|
||||
let layer_sizes = ann.get_layer_sizes();
|
||||
let bias_counts = ann.get_bias_counts();
|
||||
|
||||
println!("Layers:");
|
||||
for (layer_size, bias_count) in layer_sizes.iter().zip(bias_counts.iter()) {
|
||||
println!("{} {}", layer_size, bias_count);
|
||||
}
|
||||
|
||||
// Output connections
|
||||
println!("Connections:");
|
||||
let connections = ann.get_connections();
|
||||
|
||||
for connection in connections {
|
||||
println!("{} {} {}", connection.from_neuron, connection.to_neuron, connection.weight);
|
||||
}
|
||||
},
|
||||
Err(err) => {
|
||||
eprintln!("Error loading network from file {}: {}", network_file, err);
|
||||
std::process::exit(1);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -19,4 +19,7 @@ serde = { version = "1.0", features = ["derive"] }
|
|||
thiserror = "1.0"
|
||||
anyhow = "1.0"
|
||||
bincode = "1.3.3"
|
||||
log = "0.4.14"
|
||||
log = "0.4.14"
|
||||
serde_json = "1.0.114"
|
||||
tokio = { version = "1.37.0", features = ["full"] }
|
||||
futures = "0.3.30"
|
||||
|
|
5
file_linked/src/constants/data_format.rs
Normal file
5
file_linked/src/constants/data_format.rs
Normal file
|
@ -0,0 +1,5 @@
|
|||
#[derive(Debug)]
|
||||
pub enum DataFormat {
|
||||
Bincode,
|
||||
Json,
|
||||
}
|
1
file_linked/src/constants/mod.rs
Normal file
1
file_linked/src/constants/mod.rs
Normal file
|
@ -0,0 +1 @@
|
|||
pub mod data_format;
|
|
@ -1,8 +1,10 @@
|
|||
//! A wrapper around an object that ties it to a physical file
|
||||
|
||||
pub mod constants;
|
||||
pub mod error;
|
||||
|
||||
use anyhow::{anyhow, Context};
|
||||
use constants::data_format::DataFormat;
|
||||
use error::Error;
|
||||
use log::info;
|
||||
use serde::{de::DeserializeOwned, Serialize};
|
||||
|
@ -10,9 +12,10 @@ use std::{
|
|||
fs::{copy, remove_file, File},
|
||||
io::{ErrorKind, Write},
|
||||
path::{Path, PathBuf},
|
||||
thread,
|
||||
thread::JoinHandle,
|
||||
sync::Arc,
|
||||
thread::{self, JoinHandle},
|
||||
};
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
/// A wrapper around an object `T` that ties the object to a physical file
|
||||
#[derive(Debug)]
|
||||
|
@ -20,10 +23,11 @@ pub struct FileLinked<T>
|
|||
where
|
||||
T: Serialize,
|
||||
{
|
||||
val: T,
|
||||
val: Arc<RwLock<T>>,
|
||||
path: PathBuf,
|
||||
temp_file_path: PathBuf,
|
||||
file_thread: Option<JoinHandle<()>>,
|
||||
data_format: DataFormat,
|
||||
}
|
||||
|
||||
impl<T> Drop for FileLinked<T>
|
||||
|
@ -48,10 +52,12 @@ where
|
|||
/// # Examples
|
||||
/// ```
|
||||
/// # use file_linked::*;
|
||||
/// # use file_linked::constants::data_format::DataFormat;
|
||||
/// # use serde::{Deserialize, Serialize};
|
||||
/// # use std::fmt;
|
||||
/// # use std::string::ToString;
|
||||
/// # use std::path::PathBuf;
|
||||
/// # use tokio;
|
||||
/// #
|
||||
/// # #[derive(Deserialize, Serialize)]
|
||||
/// # struct Test {
|
||||
|
@ -60,27 +66,30 @@ where
|
|||
/// # pub c: f64
|
||||
/// # }
|
||||
/// #
|
||||
/// # fn main() {
|
||||
/// # #[tokio::main]
|
||||
/// # async fn main() {
|
||||
/// let test = Test {
|
||||
/// a: 1,
|
||||
/// b: String::from("two"),
|
||||
/// c: 3.0
|
||||
/// };
|
||||
///
|
||||
/// let linked_test = FileLinked::new(test, &PathBuf::from("./temp"))
|
||||
/// let linked_test = FileLinked::new(test, &PathBuf::from("./temp"), DataFormat::Json).await
|
||||
/// .expect("Unable to create file linked object");
|
||||
///
|
||||
/// assert_eq!(linked_test.readonly().a, 1);
|
||||
/// assert_eq!(linked_test.readonly().b, String::from("two"));
|
||||
/// assert_eq!(linked_test.readonly().c, 3.0);
|
||||
/// let readonly = linked_test.readonly();
|
||||
/// let readonly_ref = readonly.read().await;
|
||||
/// assert_eq!(readonly_ref.a, 1);
|
||||
/// assert_eq!(readonly_ref.b, String::from("two"));
|
||||
/// assert_eq!(readonly_ref.c, 3.0);
|
||||
/// #
|
||||
/// # drop(linked_test);
|
||||
/// #
|
||||
/// # std::fs::remove_file("./temp").expect("Unable to remove file");
|
||||
/// # }
|
||||
/// ```
|
||||
pub fn readonly(&self) -> &T {
|
||||
&self.val
|
||||
pub fn readonly(&self) -> Arc<RwLock<T>> {
|
||||
self.val.clone()
|
||||
}
|
||||
|
||||
/// Creates a new [`FileLinked`] object of type `T` stored to the file given by `path`.
|
||||
|
@ -88,10 +97,12 @@ where
|
|||
/// # Examples
|
||||
/// ```
|
||||
/// # use file_linked::*;
|
||||
/// # use file_linked::constants::data_format::DataFormat;
|
||||
/// # use serde::{Deserialize, Serialize};
|
||||
/// # use std::fmt;
|
||||
/// # use std::string::ToString;
|
||||
/// # use std::path::PathBuf;
|
||||
/// # use tokio;
|
||||
/// #
|
||||
/// #[derive(Deserialize, Serialize)]
|
||||
/// struct Test {
|
||||
|
@ -100,26 +111,29 @@ where
|
|||
/// pub c: f64
|
||||
/// }
|
||||
///
|
||||
/// # fn main() {
|
||||
/// #[tokio::main]
|
||||
/// # async fn main() {
|
||||
/// let test = Test {
|
||||
/// a: 1,
|
||||
/// b: String::from("two"),
|
||||
/// c: 3.0
|
||||
/// };
|
||||
///
|
||||
/// let linked_test = FileLinked::new(test, &PathBuf::from("./temp"))
|
||||
/// let linked_test = FileLinked::new(test, &PathBuf::from("./temp"), DataFormat::Json).await
|
||||
/// .expect("Unable to create file linked object");
|
||||
///
|
||||
/// assert_eq!(linked_test.readonly().a, 1);
|
||||
/// assert_eq!(linked_test.readonly().b, String::from("two"));
|
||||
/// assert_eq!(linked_test.readonly().c, 3.0);
|
||||
/// let readonly = linked_test.readonly();
|
||||
/// let readonly_ref = readonly.read().await;
|
||||
/// assert_eq!(readonly_ref.a, 1);
|
||||
/// assert_eq!(readonly_ref.b, String::from("two"));
|
||||
/// assert_eq!(readonly_ref.c, 3.0);
|
||||
/// #
|
||||
/// # drop(linked_test);
|
||||
/// #
|
||||
/// # std::fs::remove_file("./temp").expect("Unable to remove file");
|
||||
/// # }
|
||||
/// ```
|
||||
pub fn new(val: T, path: &Path) -> Result<FileLinked<T>, Error> {
|
||||
pub async fn new(val: T, path: &Path, data_format: DataFormat) -> Result<FileLinked<T>, Error> {
|
||||
let mut temp_file_path = path.to_path_buf();
|
||||
temp_file_path.set_file_name(format!(
|
||||
".temp{}",
|
||||
|
@ -130,21 +144,28 @@ where
|
|||
));
|
||||
|
||||
let mut result = FileLinked {
|
||||
val,
|
||||
val: Arc::new(RwLock::new(val)),
|
||||
path: path.to_path_buf(),
|
||||
temp_file_path,
|
||||
file_thread: None,
|
||||
data_format,
|
||||
};
|
||||
|
||||
result.write_data()?;
|
||||
result.write_data().await?;
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
fn write_data(&mut self) -> Result<(), Error> {
|
||||
async fn write_data(&mut self) -> Result<(), Error> {
|
||||
let thread_path = self.path.clone();
|
||||
let thread_temp_path = self.temp_file_path.clone();
|
||||
let thread_val = bincode::serialize(&self.val)
|
||||
.with_context(|| "Unable to serialize object into bincode".to_string())?;
|
||||
let val = self.val.read().await;
|
||||
|
||||
let thread_val = match self.data_format {
|
||||
DataFormat::Bincode => bincode::serialize(&*val)
|
||||
.with_context(|| "Unable to serialize object into bincode".to_string())?,
|
||||
DataFormat::Json => serde_json::to_vec(&*val)
|
||||
.with_context(|| "Unable to serialize object into JSON".to_string())?,
|
||||
};
|
||||
|
||||
if let Some(file_thread) = self.file_thread.take() {
|
||||
file_thread
|
||||
|
@ -190,10 +211,12 @@ where
|
|||
/// ```
|
||||
/// # use file_linked::*;
|
||||
/// # use file_linked::error::Error;
|
||||
/// # use file_linked::constants::data_format::DataFormat;
|
||||
/// # use serde::{Deserialize, Serialize};
|
||||
/// # use std::fmt;
|
||||
/// # use std::string::ToString;
|
||||
/// # use std::path::PathBuf;
|
||||
/// # use tokio;
|
||||
/// #
|
||||
/// # #[derive(Deserialize, Serialize)]
|
||||
/// # struct Test {
|
||||
|
@ -202,21 +225,28 @@ where
|
|||
/// # pub c: f64
|
||||
/// # }
|
||||
/// #
|
||||
/// # fn main() -> Result<(), Error> {
|
||||
/// # #[tokio::main]
|
||||
/// # async fn main() -> Result<(), Error> {
|
||||
/// let test = Test {
|
||||
/// a: 1,
|
||||
/// b: String::from(""),
|
||||
/// c: 0.0
|
||||
/// };
|
||||
///
|
||||
/// let mut linked_test = FileLinked::new(test, &PathBuf::from("./temp"))
|
||||
/// let mut linked_test = FileLinked::new(test, &PathBuf::from("./temp"), DataFormat::Bincode).await
|
||||
/// .expect("Unable to create file linked object");
|
||||
///
|
||||
/// assert_eq!(linked_test.readonly().a, 1);
|
||||
/// {
|
||||
/// let readonly = linked_test.readonly();
|
||||
/// let readonly_ref = readonly.read().await;
|
||||
/// assert_eq!(readonly_ref.a, 1);
|
||||
/// }
|
||||
///
|
||||
/// linked_test.mutate(|t| t.a = 2)?;
|
||||
/// linked_test.mutate(|t| t.a = 2).await?;
|
||||
///
|
||||
/// assert_eq!(linked_test.readonly().a, 2);
|
||||
/// let readonly = linked_test.readonly();
|
||||
/// let readonly_ref = readonly.read().await;
|
||||
/// assert_eq!(readonly_ref.a, 2);
|
||||
/// #
|
||||
/// # drop(linked_test);
|
||||
/// #
|
||||
|
@ -225,10 +255,15 @@ where
|
|||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
pub fn mutate<U, F: FnOnce(&mut T) -> U>(&mut self, op: F) -> Result<U, Error> {
|
||||
let result = op(&mut self.val);
|
||||
pub async fn mutate<U, F: FnOnce(&mut T) -> U>(&mut self, op: F) -> Result<U, Error> {
|
||||
let val_clone = self.val.clone(); // Arc<RwLock<T>>
|
||||
let mut val = val_clone.write().await; // RwLockWriteGuard<T>
|
||||
|
||||
self.write_data()?;
|
||||
let result = op(&mut val);
|
||||
|
||||
drop(val);
|
||||
|
||||
self.write_data().await?;
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
@ -239,10 +274,12 @@ where
|
|||
/// ```
|
||||
/// # use file_linked::*;
|
||||
/// # use file_linked::error::Error;
|
||||
/// # use file_linked::constants::data_format::DataFormat;
|
||||
/// # use serde::{Deserialize, Serialize};
|
||||
/// # use std::fmt;
|
||||
/// # use std::string::ToString;
|
||||
/// # use std::path::PathBuf;
|
||||
/// # use tokio;
|
||||
/// #
|
||||
/// # #[derive(Deserialize, Serialize)]
|
||||
/// # struct Test {
|
||||
|
@ -251,25 +288,30 @@ where
|
|||
/// # pub c: f64
|
||||
/// # }
|
||||
/// #
|
||||
/// # fn main() -> Result<(), Error> {
|
||||
/// # #[tokio::main]
|
||||
/// # async fn main() -> Result<(), Error> {
|
||||
/// let test = Test {
|
||||
/// a: 1,
|
||||
/// b: String::from(""),
|
||||
/// c: 0.0
|
||||
/// };
|
||||
///
|
||||
/// let mut linked_test = FileLinked::new(test, &PathBuf::from("./temp"))
|
||||
/// let mut linked_test = FileLinked::new(test, &PathBuf::from("./temp"), DataFormat::Bincode).await
|
||||
/// .expect("Unable to create file linked object");
|
||||
///
|
||||
/// assert_eq!(linked_test.readonly().a, 1);
|
||||
/// let readonly = linked_test.readonly();
|
||||
/// let readonly_ref = readonly.read().await;
|
||||
/// assert_eq!(readonly_ref.a, 1);
|
||||
///
|
||||
/// linked_test.replace(Test {
|
||||
/// a: 2,
|
||||
/// b: String::from(""),
|
||||
/// c: 0.0
|
||||
/// })?;
|
||||
/// }).await?;
|
||||
///
|
||||
/// assert_eq!(linked_test.readonly().a, 2);
|
||||
/// let readonly = linked_test.readonly();
|
||||
/// let readonly_ref = readonly.read().await;
|
||||
/// assert_eq!(readonly_ref.a, 2);
|
||||
/// #
|
||||
/// # drop(linked_test);
|
||||
/// #
|
||||
|
@ -278,10 +320,30 @@ where
|
|||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
pub fn replace(&mut self, val: T) -> Result<(), Error> {
|
||||
self.val = val;
|
||||
pub async fn replace(&mut self, val: T) -> Result<(), Error> {
|
||||
self.val = Arc::new(RwLock::new(val));
|
||||
|
||||
self.write_data()
|
||||
self.write_data().await
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> FileLinked<T>
|
||||
where
|
||||
T: Serialize + DeserializeOwned + Send + 'static,
|
||||
{
|
||||
/// Asynchronously modifies the data contained in a `FileLinked` object using an async callback `op`.
|
||||
pub async fn mutate_async<F, Fut, U>(&mut self, op: F) -> Result<U, Error>
|
||||
where
|
||||
F: FnOnce(Arc<RwLock<T>>) -> Fut,
|
||||
Fut: std::future::Future<Output = U> + Send,
|
||||
U: Send,
|
||||
{
|
||||
let val_clone = self.val.clone();
|
||||
let result = op(val_clone).await;
|
||||
|
||||
self.write_data().await?;
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -295,6 +357,7 @@ where
|
|||
/// ```
|
||||
/// # use file_linked::*;
|
||||
/// # use file_linked::error::Error;
|
||||
/// # use file_linked::constants::data_format::DataFormat;
|
||||
/// # use serde::{Deserialize, Serialize};
|
||||
/// # use std::fmt;
|
||||
/// # use std::string::ToString;
|
||||
|
@ -302,6 +365,7 @@ where
|
|||
/// # use std::fs::OpenOptions;
|
||||
/// # use std::io::Write;
|
||||
/// # use std::path::PathBuf;
|
||||
/// # use tokio;
|
||||
/// #
|
||||
/// # #[derive(Deserialize, Serialize)]
|
||||
/// # struct Test {
|
||||
|
@ -310,7 +374,8 @@ where
|
|||
/// # pub c: f64
|
||||
/// # }
|
||||
/// #
|
||||
/// # fn main() -> Result<(), Error> {
|
||||
/// # #[tokio::main]
|
||||
/// # async fn main() -> Result<(), Error> {
|
||||
/// let test = Test {
|
||||
/// a: 1,
|
||||
/// b: String::from("2"),
|
||||
|
@ -327,12 +392,14 @@ where
|
|||
///
|
||||
/// bincode::serialize_into(file, &test).expect("Unable to serialize object");
|
||||
///
|
||||
/// let mut linked_test = FileLinked::<Test>::from_file(&path)
|
||||
/// let mut linked_test = FileLinked::<Test>::from_file(&path, DataFormat::Bincode)
|
||||
/// .expect("Unable to create file linked object");
|
||||
///
|
||||
/// assert_eq!(linked_test.readonly().a, test.a);
|
||||
/// assert_eq!(linked_test.readonly().b, test.b);
|
||||
/// assert_eq!(linked_test.readonly().c, test.c);
|
||||
/// let readonly = linked_test.readonly();
|
||||
/// let readonly_ref = readonly.read().await;
|
||||
/// assert_eq!(readonly_ref.a, test.a);
|
||||
/// assert_eq!(readonly_ref.b, test.b);
|
||||
/// assert_eq!(readonly_ref.c, test.c);
|
||||
/// #
|
||||
/// # drop(linked_test);
|
||||
/// #
|
||||
|
@ -341,7 +408,7 @@ where
|
|||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
pub fn from_file(path: &Path) -> Result<FileLinked<T>, Error> {
|
||||
pub fn from_file(path: &Path, data_format: DataFormat) -> Result<FileLinked<T>, Error> {
|
||||
let mut temp_file_path = path.to_path_buf();
|
||||
temp_file_path.set_file_name(format!(
|
||||
".temp{}",
|
||||
|
@ -351,16 +418,22 @@ where
|
|||
.ok_or_else(|| anyhow!("Unable to get filename for tempfile {}", path.display()))?
|
||||
));
|
||||
|
||||
match File::open(path).map_err(Error::from).and_then(|file| {
|
||||
bincode::deserialize_from::<File, T>(file)
|
||||
.with_context(|| format!("Unable to deserialize file {}", path.display()))
|
||||
.map_err(Error::from)
|
||||
}) {
|
||||
match File::open(path)
|
||||
.map_err(Error::from)
|
||||
.and_then(|file| match data_format {
|
||||
DataFormat::Bincode => bincode::deserialize_from::<File, T>(file)
|
||||
.with_context(|| format!("Unable to deserialize file {}", path.display()))
|
||||
.map_err(Error::from),
|
||||
DataFormat::Json => serde_json::from_reader(file)
|
||||
.with_context(|| format!("Unable to deserialize file {}", path.display()))
|
||||
.map_err(Error::from),
|
||||
}) {
|
||||
Ok(val) => Ok(FileLinked {
|
||||
val,
|
||||
val: Arc::new(RwLock::new(val)),
|
||||
path: path.to_path_buf(),
|
||||
temp_file_path,
|
||||
file_thread: None,
|
||||
data_format,
|
||||
}),
|
||||
Err(err) => {
|
||||
info!(
|
||||
|
@ -370,30 +443,43 @@ where
|
|||
);
|
||||
|
||||
// Try to use temp file instead and see if that file exists and is serializable
|
||||
let val = FileLinked::from_temp_file(&temp_file_path, path)
|
||||
let val = FileLinked::from_temp_file(&temp_file_path, path, &data_format)
|
||||
.map_err(|_| err)
|
||||
.with_context(|| format!("Failed to read/deserialize the object from the file {} and temp file {}", path.display(), temp_file_path.display()))?;
|
||||
|
||||
Ok(FileLinked {
|
||||
val,
|
||||
val: Arc::new(RwLock::new(val)),
|
||||
path: path.to_path_buf(),
|
||||
temp_file_path,
|
||||
file_thread: None,
|
||||
data_format,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn from_temp_file(temp_file_path: &Path, path: &Path) -> Result<T, Error> {
|
||||
fn from_temp_file(
|
||||
temp_file_path: &Path,
|
||||
path: &Path,
|
||||
data_format: &DataFormat,
|
||||
) -> Result<T, Error> {
|
||||
let file = File::open(temp_file_path)
|
||||
.with_context(|| format!("Unable to open file {}", temp_file_path.display()))?;
|
||||
|
||||
let val = bincode::deserialize_from(file).with_context(|| {
|
||||
format!(
|
||||
"Could not deserialize from temp file {}",
|
||||
temp_file_path.display()
|
||||
)
|
||||
})?;
|
||||
let val = match data_format {
|
||||
DataFormat::Bincode => bincode::deserialize_from(file).with_context(|| {
|
||||
format!(
|
||||
"Could not deserialize from temp file {}",
|
||||
temp_file_path.display()
|
||||
)
|
||||
})?,
|
||||
DataFormat::Json => serde_json::from_reader(file).with_context(|| {
|
||||
format!(
|
||||
"Could not deserialize from temp file {}",
|
||||
temp_file_path.display()
|
||||
)
|
||||
})?,
|
||||
};
|
||||
|
||||
info!("Successfully deserialized value from temp file");
|
||||
|
||||
|
@ -421,8 +507,12 @@ mod tests {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn run<F: FnOnce(&Path) -> Result<(), Error>>(&self, op: F) -> Result<(), Error> {
|
||||
op(&self.path)
|
||||
pub async fn run<F, Fut>(&self, op: F) -> ()
|
||||
where
|
||||
F: FnOnce(PathBuf) -> Fut,
|
||||
Fut: std::future::Future<Output = ()>,
|
||||
{
|
||||
op(self.path.clone()).await
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -434,92 +524,173 @@ mod tests {
|
|||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_readonly() -> Result<(), Error> {
|
||||
#[tokio::test]
|
||||
async fn test_readonly() {
|
||||
let path = PathBuf::from("test_readonly");
|
||||
let cleanup = CleanUp::new(&path);
|
||||
cleanup.run(|p| {
|
||||
let val = vec!["one", "two", ""];
|
||||
cleanup
|
||||
.run(|p| async move {
|
||||
let val = vec!["one", "two", ""];
|
||||
|
||||
let linked_object = FileLinked::new(val.clone(), &p)?;
|
||||
assert_eq!(*linked_object.readonly(), val);
|
||||
|
||||
Ok(())
|
||||
})
|
||||
let linked_object = FileLinked::new(val.clone(), &p, DataFormat::Json)
|
||||
.await
|
||||
.expect("Unable to create file linked object");
|
||||
let linked_object_arc = linked_object.readonly();
|
||||
let linked_object_ref = linked_object_arc.read().await;
|
||||
assert_eq!(*linked_object_ref, val);
|
||||
})
|
||||
.await;
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_new() -> Result<(), Error> {
|
||||
#[tokio::test]
|
||||
async fn test_new() {
|
||||
let path = PathBuf::from("test_new");
|
||||
let cleanup = CleanUp::new(&path);
|
||||
cleanup.run(|p| {
|
||||
let val = "test";
|
||||
cleanup
|
||||
.run(|p| async move {
|
||||
let val = "test";
|
||||
|
||||
FileLinked::new(val, &p)?;
|
||||
FileLinked::new(val, &p, DataFormat::Bincode)
|
||||
.await
|
||||
.expect("Unable to create file linked object");
|
||||
|
||||
let file = File::open(&p)?;
|
||||
let result: String =
|
||||
bincode::deserialize_from(file).expect("Unable to deserialize from file");
|
||||
assert_eq!(result, val);
|
||||
|
||||
Ok(())
|
||||
})
|
||||
let file = File::open(&p).expect("Unable to open file");
|
||||
let result: String =
|
||||
bincode::deserialize_from(file).expect("Unable to deserialize from file");
|
||||
assert_eq!(result, val);
|
||||
})
|
||||
.await;
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mutate() -> Result<(), Error> {
|
||||
#[tokio::test]
|
||||
async fn test_mutate() {
|
||||
let path = PathBuf::from("test_mutate");
|
||||
let cleanup = CleanUp::new(&path);
|
||||
cleanup.run(|p| {
|
||||
let list = vec![1, 2, 3, 4];
|
||||
let mut file_linked_list = FileLinked::new(list, &p)?;
|
||||
assert_eq!(*file_linked_list.readonly(), vec![1, 2, 3, 4]);
|
||||
cleanup
|
||||
.run(|p| async move {
|
||||
let list = vec![1, 2, 3, 4];
|
||||
let mut file_linked_list = FileLinked::new(list, &p, DataFormat::Json)
|
||||
.await
|
||||
.expect("Unable to create file linked object");
|
||||
let file_linked_list_arc = file_linked_list.readonly();
|
||||
let file_linked_list_ref = file_linked_list_arc.read().await;
|
||||
|
||||
file_linked_list.mutate(|v1| v1.push(5))?;
|
||||
assert_eq!(*file_linked_list.readonly(), vec![1, 2, 3, 4, 5]);
|
||||
assert_eq!(*file_linked_list_ref, vec![1, 2, 3, 4]);
|
||||
|
||||
file_linked_list.mutate(|v1| v1[1] = 1)?;
|
||||
assert_eq!(*file_linked_list.readonly(), vec![1, 1, 3, 4, 5]);
|
||||
drop(file_linked_list_ref);
|
||||
file_linked_list
|
||||
.mutate(|v1| v1.push(5))
|
||||
.await
|
||||
.expect("Error mutating file linked object");
|
||||
let file_linked_list_arc = file_linked_list.readonly();
|
||||
let file_linked_list_ref = file_linked_list_arc.read().await;
|
||||
|
||||
drop(file_linked_list);
|
||||
Ok(())
|
||||
})
|
||||
assert_eq!(*file_linked_list_ref, vec![1, 2, 3, 4, 5]);
|
||||
|
||||
drop(file_linked_list_ref);
|
||||
file_linked_list
|
||||
.mutate(|v1| v1[1] = 1)
|
||||
.await
|
||||
.expect("Error mutating file linked object");
|
||||
let file_linked_list_arc = file_linked_list.readonly();
|
||||
let file_linked_list_ref = file_linked_list_arc.read().await;
|
||||
|
||||
assert_eq!(*file_linked_list_ref, vec![1, 1, 3, 4, 5]);
|
||||
|
||||
drop(file_linked_list);
|
||||
})
|
||||
.await;
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_replace() -> Result<(), Error> {
|
||||
#[tokio::test]
|
||||
async fn test_async_mutate() {
|
||||
let path = PathBuf::from("test_async_mutate");
|
||||
let cleanup = CleanUp::new(&path);
|
||||
cleanup
|
||||
.run(|p| async move {
|
||||
let list = vec![1, 2, 3, 4];
|
||||
let mut file_linked_list = FileLinked::new(list, &p, DataFormat::Json)
|
||||
.await
|
||||
.expect("Unable to create file linked object");
|
||||
let file_linked_list_arc = file_linked_list.readonly();
|
||||
let file_linked_list_ref = file_linked_list_arc.read().await;
|
||||
|
||||
assert_eq!(*file_linked_list_ref, vec![1, 2, 3, 4]);
|
||||
|
||||
drop(file_linked_list_ref);
|
||||
file_linked_list
|
||||
.mutate_async(|v1| async move {
|
||||
let mut v = v1.write().await;
|
||||
v.push(5);
|
||||
v[1] = 1;
|
||||
Ok::<(), Error>(())
|
||||
})
|
||||
.await
|
||||
.expect("Error mutating file linked object")
|
||||
.expect("Error mutating file linked object");
|
||||
|
||||
let file_linked_list_arc = file_linked_list.readonly();
|
||||
let file_linked_list_ref = file_linked_list_arc.read().await;
|
||||
|
||||
assert_eq!(*file_linked_list_ref, vec![1, 1, 3, 4, 5]);
|
||||
|
||||
drop(file_linked_list);
|
||||
})
|
||||
.await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_replace() {
|
||||
let path = PathBuf::from("test_replace");
|
||||
let cleanup = CleanUp::new(&path);
|
||||
cleanup.run(|p| {
|
||||
let val1 = String::from("val1");
|
||||
let val2 = String::from("val2");
|
||||
let mut file_linked_list = FileLinked::new(val1.clone(), &p)?;
|
||||
assert_eq!(*file_linked_list.readonly(), val1);
|
||||
cleanup
|
||||
.run(|p| async move {
|
||||
let val1 = String::from("val1");
|
||||
let val2 = String::from("val2");
|
||||
let mut file_linked_list = FileLinked::new(val1.clone(), &p, DataFormat::Bincode)
|
||||
.await
|
||||
.expect("Unable to create file linked object");
|
||||
let file_linked_list_arc = file_linked_list.readonly();
|
||||
let file_linked_list_ref = file_linked_list_arc.read().await;
|
||||
|
||||
file_linked_list.replace(val2.clone())?;
|
||||
assert_eq!(*file_linked_list.readonly(), val2);
|
||||
assert_eq!(*file_linked_list_ref, val1);
|
||||
|
||||
drop(file_linked_list);
|
||||
Ok(())
|
||||
})
|
||||
file_linked_list
|
||||
.replace(val2.clone())
|
||||
.await
|
||||
.expect("Error replacing file linked object");
|
||||
let file_linked_list_arc = file_linked_list.readonly();
|
||||
let file_linked_list_ref = file_linked_list_arc.read().await;
|
||||
|
||||
assert_eq!(*file_linked_list_ref, val2);
|
||||
|
||||
drop(file_linked_list);
|
||||
})
|
||||
.await;
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_from_file() -> Result<(), Error> {
|
||||
#[tokio::test]
|
||||
async fn test_from_file() {
|
||||
let path = PathBuf::from("test_from_file");
|
||||
let cleanup = CleanUp::new(&path);
|
||||
cleanup.run(|p| {
|
||||
let value: Vec<f64> = vec![2.0, 3.0, 5.0];
|
||||
let file = File::create(&p)?;
|
||||
cleanup
|
||||
.run(|p| async move {
|
||||
let value: Vec<f64> = vec![2.0, 3.0, 5.0];
|
||||
let file = File::create(&p).expect("Unable to create file");
|
||||
|
||||
bincode::serialize_into(&file, &value).expect("Unable to serialize into file");
|
||||
drop(file);
|
||||
bincode::serialize_into(&file, &value).expect("Unable to serialize into file");
|
||||
drop(file);
|
||||
|
||||
let linked_object: FileLinked<Vec<f64>> = FileLinked::from_file(&p)?;
|
||||
assert_eq!(*linked_object.readonly(), value);
|
||||
let linked_object: FileLinked<Vec<f64>> =
|
||||
FileLinked::from_file(&p, DataFormat::Bincode)
|
||||
.expect("Unable to create file linked object");
|
||||
let linked_object_arc = linked_object.readonly();
|
||||
let linked_object_ref = linked_object_arc.read().await;
|
||||
|
||||
drop(linked_object);
|
||||
Ok(())
|
||||
})
|
||||
assert_eq!(*linked_object_ref, value);
|
||||
|
||||
drop(linked_object);
|
||||
})
|
||||
.await;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -15,18 +15,22 @@ categories = ["simulation"]
|
|||
[dependencies]
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
serde_json = "1.0"
|
||||
uuid = { version = "0.8", features = ["serde", "v4"] }
|
||||
clap = { version = "~2.27.0", features = ["yaml"] }
|
||||
toml = "0.5.8"
|
||||
uuid = { version = "1.7", features = ["serde", "v4"] }
|
||||
clap = { version = "4.5.2", features = ["derive"] }
|
||||
toml = "0.8.10"
|
||||
regex = "1"
|
||||
file_linked = { version = "0.1.0", path = "../file_linked" }
|
||||
thiserror = "1.0"
|
||||
anyhow = "1.0"
|
||||
rand = "0.8.4"
|
||||
log = "0.4.14"
|
||||
env_logger = "0.9.0"
|
||||
futures = "0.3.17"
|
||||
smol = "1.2.5"
|
||||
smol-potat = "1.1.2"
|
||||
num_cpus = "1.13.0"
|
||||
easy-parallel = "3.1.0"
|
||||
rand = "0.8.5"
|
||||
log = "0.4.21"
|
||||
env_logger = "0.11.3"
|
||||
futures = "0.3.30"
|
||||
tokio = { version = "1.37.0", features = ["full"] }
|
||||
num_cpus = "1.16.0"
|
||||
easy-parallel = "3.3.1"
|
||||
fann = "0.1.8"
|
||||
async-trait = "0.1.78"
|
||||
async-recursion = "1.1.0"
|
||||
lerp = "0.5.0"
|
||||
console-subscriber = "0.2.0"
|
||||
|
|
11
gemla/build.rs
Normal file
11
gemla/build.rs
Normal file
|
@ -0,0 +1,11 @@
|
|||
fn main() {
|
||||
// Replace this with the path to the directory containing `fann.lib`
|
||||
let lib_dir = "F://vandomej/Downloads/vcpkg/packages/fann_x64-windows/lib";
|
||||
|
||||
println!("cargo:rustc-link-search=native={}", lib_dir);
|
||||
println!("cargo:rustc-link-lib=static=fann");
|
||||
// Use `dylib=fann` instead of `static=fann` if you're linking dynamically
|
||||
|
||||
// If there are any additional directories where the compiler can find header files, you can specify them like this:
|
||||
// println!("cargo:include={}", path_to_include_directory);
|
||||
}
|
|
@ -1,9 +0,0 @@
|
|||
name: GEMLA
|
||||
version: "0.1"
|
||||
autor: Jacob VanDomelen <jacob.vandome15@gmail.com>
|
||||
about: Uses a genetic algorithm to generate a machine learning algorithm.
|
||||
args:
|
||||
- FILE:
|
||||
help: Sets the input/output file for the program.
|
||||
required: true
|
||||
index: 1
|
|
@ -1,15 +0,0 @@
|
|||
[[nodes]]
|
||||
fabric_addr = "10.0.0.1:9999"
|
||||
bridge_bind = "10.0.0.1:8888"
|
||||
mem = "100 GiB"
|
||||
cpu = 8
|
||||
|
||||
# [[nodes]]
|
||||
# fabric_addr = "10.0.0.2:9999"
|
||||
# mem = "100 GiB"
|
||||
# cpu = 16
|
||||
|
||||
# [[nodes]]
|
||||
# fabric_addr = "10.0.0.3:9999"
|
||||
# mem = "100 GiB"
|
||||
# cpu = 16
|
|
@ -1,74 +1,72 @@
|
|||
#[macro_use]
|
||||
extern crate clap;
|
||||
extern crate gemla;
|
||||
#[macro_use]
|
||||
extern crate log;
|
||||
|
||||
mod fighter_nn;
|
||||
mod test_state;
|
||||
|
||||
use anyhow::anyhow;
|
||||
use clap::App;
|
||||
use easy_parallel::Parallel;
|
||||
use anyhow::Result;
|
||||
use clap::Parser;
|
||||
use fighter_nn::FighterNN;
|
||||
use file_linked::constants::data_format::DataFormat;
|
||||
use gemla::{
|
||||
constants::args::FILE,
|
||||
core::{Gemla, GemlaConfig},
|
||||
error::{log_error, Error},
|
||||
error::log_error,
|
||||
};
|
||||
use smol::{channel, channel::RecvError, future, Executor};
|
||||
use std::{path::PathBuf, time::Instant};
|
||||
use test_state::TestState;
|
||||
|
||||
// const NUM_THREADS: usize = 2;
|
||||
|
||||
#[derive(Parser)]
|
||||
#[command(version, about, long_about = None)]
|
||||
struct Args {
|
||||
/// The file to read/write the dataset from/to.
|
||||
#[arg(short, long)]
|
||||
file: String,
|
||||
}
|
||||
|
||||
/// Runs a simluation of a genetic algorithm against a dataset.
|
||||
///
|
||||
/// Use the -h, --h, or --help flag to see usage syntax.
|
||||
/// TODO
|
||||
fn main() -> anyhow::Result<()> {
|
||||
fn main() -> Result<()> {
|
||||
env_logger::init();
|
||||
info!("Starting");
|
||||
// console_subscriber::init();
|
||||
|
||||
info!("Starting");
|
||||
let now = Instant::now();
|
||||
|
||||
// Obtainning number of threads to use
|
||||
let num_threads = num_cpus::get().max(1);
|
||||
let ex = Executor::new();
|
||||
let (signal, shutdown) = channel::unbounded::<()>();
|
||||
// Manually configure the Tokio runtime
|
||||
let runtime: Result<()> = tokio::runtime::Builder::new_multi_thread()
|
||||
.worker_threads(num_cpus::get())
|
||||
// .worker_threads(NUM_THREADS)
|
||||
.build()?
|
||||
.block_on(async {
|
||||
let args = Args::parse(); // Assuming Args::parse() doesn't need to be async
|
||||
let mut gemla = log_error(
|
||||
Gemla::<FighterNN>::new(
|
||||
&PathBuf::from(args.file),
|
||||
GemlaConfig { overwrite: false },
|
||||
DataFormat::Json,
|
||||
)
|
||||
.await,
|
||||
)?;
|
||||
|
||||
// Create an executor thread pool.
|
||||
let (_, result): (Vec<Result<(), RecvError>>, Result<(), Error>) = Parallel::new()
|
||||
.each(0..num_threads, |_| {
|
||||
future::block_on(ex.run(shutdown.recv()))
|
||||
})
|
||||
.finish(|| {
|
||||
smol::block_on(async {
|
||||
drop(signal);
|
||||
// let gemla_arc = Arc::new(gemla);
|
||||
|
||||
// Command line arguments are parsed with the clap crate. And this program uses
|
||||
// the yaml method with clap.
|
||||
let yaml = load_yaml!("../../cli.yml");
|
||||
let matches = App::from_yaml(yaml).get_matches();
|
||||
// Setup your application logic here
|
||||
// If `gemla::simulate` needs to run sequentially, simply call it in sequence without spawning new tasks
|
||||
|
||||
// Checking that the first argument <FILE> is a valid file
|
||||
if let Some(file_path) = matches.value_of(FILE) {
|
||||
let mut gemla = log_error(Gemla::<TestState>::new(
|
||||
&PathBuf::from(file_path),
|
||||
GemlaConfig {
|
||||
generations_per_node: 3,
|
||||
overwrite: true,
|
||||
},
|
||||
))?;
|
||||
|
||||
log_error(gemla.simulate(3).await)?;
|
||||
|
||||
Ok(())
|
||||
} else {
|
||||
Err(Error::Other(anyhow!("Invalid argument for FILE")))
|
||||
}
|
||||
})
|
||||
// Example placeholder loop to continuously run simulate
|
||||
loop {
|
||||
// Arbitrary loop count for demonstration
|
||||
gemla.simulate(1).await?;
|
||||
}
|
||||
});
|
||||
|
||||
result?;
|
||||
runtime?; // Handle errors from the block_on call
|
||||
|
||||
info!("Finished in {:?}", now.elapsed());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
|
79
gemla/src/bin/fighter_nn/fighter_context.rs
Normal file
79
gemla/src/bin/fighter_nn/fighter_context.rs
Normal file
|
@ -0,0 +1,79 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use serde::ser::SerializeTuple;
|
||||
use serde::{Deserialize, Deserializer, Serialize, Serializer};
|
||||
use tokio::sync::Semaphore;
|
||||
|
||||
const SHARED_SEMAPHORE_CONCURRENCY_LIMIT: usize = 50;
|
||||
const VISIBLE_SIMULATIONS_CONCURRENCY_LIMIT: usize = 1;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct FighterContext {
|
||||
pub shared_semaphore: Arc<Semaphore>,
|
||||
pub visible_simulations: Arc<Semaphore>,
|
||||
}
|
||||
|
||||
impl Default for FighterContext {
|
||||
fn default() -> Self {
|
||||
FighterContext {
|
||||
shared_semaphore: Arc::new(Semaphore::new(SHARED_SEMAPHORE_CONCURRENCY_LIMIT)),
|
||||
visible_simulations: Arc::new(Semaphore::new(VISIBLE_SIMULATIONS_CONCURRENCY_LIMIT)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Custom serialization to just output the concurrency limit.
|
||||
impl Serialize for FighterContext {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: Serializer,
|
||||
{
|
||||
// Assuming the semaphore's available permits represent the concurrency limit.
|
||||
// This part is tricky since Semaphore does not expose its initial permits.
|
||||
// You might need to store the concurrency limit as a separate field if this assumption doesn't hold.
|
||||
let concurrency_limit = SHARED_SEMAPHORE_CONCURRENCY_LIMIT;
|
||||
let visible_concurrency_limit = VISIBLE_SIMULATIONS_CONCURRENCY_LIMIT;
|
||||
// serializer.serialize_u64(concurrency_limit as u64)
|
||||
|
||||
// Serialize the concurrency limit as a tuple
|
||||
let mut state = serializer.serialize_tuple(2)?;
|
||||
state.serialize_element(&concurrency_limit)?;
|
||||
state.serialize_element(&visible_concurrency_limit)?;
|
||||
state.end()
|
||||
}
|
||||
}
|
||||
|
||||
// Custom deserialization to reconstruct the FighterContext from a concurrency limit.
|
||||
impl<'de> Deserialize<'de> for FighterContext {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
// Deserialize the tuple
|
||||
let (_, _) = <(usize, usize)>::deserialize(deserializer)?;
|
||||
Ok(FighterContext {
|
||||
shared_semaphore: Arc::new(Semaphore::new(SHARED_SEMAPHORE_CONCURRENCY_LIMIT)),
|
||||
visible_simulations: Arc::new(Semaphore::new(VISIBLE_SIMULATIONS_CONCURRENCY_LIMIT)),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_serialization() {
|
||||
let context = FighterContext::default();
|
||||
let serialized = serde_json::to_string(&context).unwrap();
|
||||
let deserialized: FighterContext = serde_json::from_str(&serialized).unwrap();
|
||||
assert_eq!(
|
||||
context.shared_semaphore.available_permits(),
|
||||
deserialized.shared_semaphore.available_permits()
|
||||
);
|
||||
assert_eq!(
|
||||
context.visible_simulations.available_permits(),
|
||||
deserialized.visible_simulations.available_permits()
|
||||
);
|
||||
}
|
||||
}
|
1631
gemla/src/bin/fighter_nn/mod.rs
Normal file
1631
gemla/src/bin/fighter_nn/mod.rs
Normal file
File diff suppressed because it is too large
Load diff
1825
gemla/src/bin/fighter_nn/neural_network_utility.rs
Normal file
1825
gemla/src/bin/fighter_nn/neural_network_utility.rs
Normal file
File diff suppressed because it is too large
Load diff
|
@ -1,6 +1,11 @@
|
|||
use gemla::{core::genetic_node::GeneticNode, error::Error};
|
||||
use async_trait::async_trait;
|
||||
use gemla::{
|
||||
core::genetic_node::{GeneticNode, GeneticNodeContext},
|
||||
error::Error,
|
||||
};
|
||||
use rand::prelude::*;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use uuid::Uuid;
|
||||
|
||||
const POPULATION_SIZE: u64 = 5;
|
||||
const POPULATION_REDUCTION_SIZE: u64 = 3;
|
||||
|
@ -8,20 +13,30 @@ const POPULATION_REDUCTION_SIZE: u64 = 3;
|
|||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
pub struct TestState {
|
||||
pub population: Vec<i64>,
|
||||
pub max_generations: u64,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl GeneticNode for TestState {
|
||||
fn initialize() -> Result<Box<Self>, Error> {
|
||||
type Context = ();
|
||||
|
||||
async fn initialize(_context: GeneticNodeContext<Self::Context>) -> Result<Box<Self>, Error> {
|
||||
let mut population: Vec<i64> = vec![];
|
||||
|
||||
for _ in 0..POPULATION_SIZE {
|
||||
population.push(thread_rng().gen_range(0..100))
|
||||
}
|
||||
|
||||
Ok(Box::new(TestState { population }))
|
||||
Ok(Box::new(TestState {
|
||||
population,
|
||||
max_generations: 10,
|
||||
}))
|
||||
}
|
||||
|
||||
fn simulate(&mut self) -> Result<(), Error> {
|
||||
async fn simulate(
|
||||
&mut self,
|
||||
context: GeneticNodeContext<Self::Context>,
|
||||
) -> Result<bool, Error> {
|
||||
let mut rng = thread_rng();
|
||||
|
||||
self.population = self
|
||||
|
@ -30,10 +45,14 @@ impl GeneticNode for TestState {
|
|||
.map(|p| p.saturating_add(rng.gen_range(-1..2)))
|
||||
.collect();
|
||||
|
||||
Ok(())
|
||||
if context.generation >= self.max_generations {
|
||||
Ok(false)
|
||||
} else {
|
||||
Ok(true)
|
||||
}
|
||||
}
|
||||
|
||||
fn mutate(&mut self) -> Result<(), Error> {
|
||||
async fn mutate(&mut self, _context: GeneticNodeContext<Self::Context>) -> Result<(), Error> {
|
||||
let mut rng = thread_rng();
|
||||
|
||||
let mut v = self.population.clone();
|
||||
|
@ -71,7 +90,12 @@ impl GeneticNode for TestState {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
fn merge(left: &TestState, right: &TestState) -> Result<Box<TestState>, Error> {
|
||||
async fn merge(
|
||||
left: &TestState,
|
||||
right: &TestState,
|
||||
id: &Uuid,
|
||||
gemla_context: Self::Context,
|
||||
) -> Result<Box<TestState>, Error> {
|
||||
let mut v = left.population.clone();
|
||||
v.append(&mut right.population.clone());
|
||||
|
||||
|
@ -80,9 +104,18 @@ impl GeneticNode for TestState {
|
|||
|
||||
v = v[..(POPULATION_REDUCTION_SIZE as usize)].to_vec();
|
||||
|
||||
let mut result = TestState { population: v };
|
||||
let mut result = TestState {
|
||||
population: v,
|
||||
max_generations: 10,
|
||||
};
|
||||
|
||||
result.mutate()?;
|
||||
result
|
||||
.mutate(GeneticNodeContext {
|
||||
id: *id,
|
||||
generation: 0,
|
||||
gemla_context,
|
||||
})
|
||||
.await?;
|
||||
|
||||
Ok(Box::new(result))
|
||||
}
|
||||
|
@ -93,57 +126,97 @@ mod tests {
|
|||
use super::*;
|
||||
use gemla::core::genetic_node::GeneticNode;
|
||||
|
||||
#[test]
|
||||
fn test_initialize() {
|
||||
let state = TestState::initialize().unwrap();
|
||||
#[tokio::test]
|
||||
async fn test_initialize() {
|
||||
let state = TestState::initialize(GeneticNodeContext {
|
||||
id: Uuid::new_v4(),
|
||||
generation: 0,
|
||||
gemla_context: (),
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(state.population.len(), POPULATION_SIZE as usize);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_simulate() {
|
||||
#[tokio::test]
|
||||
async fn test_simulate() {
|
||||
let mut state = TestState {
|
||||
population: vec![1, 1, 2, 3],
|
||||
max_generations: 1,
|
||||
};
|
||||
|
||||
let original_population = state.population.clone();
|
||||
|
||||
state.simulate().unwrap();
|
||||
state
|
||||
.simulate(GeneticNodeContext {
|
||||
id: Uuid::new_v4(),
|
||||
generation: 0,
|
||||
gemla_context: (),
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(original_population
|
||||
.iter()
|
||||
.zip(state.population.iter())
|
||||
.all(|(&a, &b)| b >= a - 1 && b <= a + 2));
|
||||
|
||||
state.simulate().unwrap();
|
||||
state.simulate().unwrap();
|
||||
state
|
||||
.simulate(GeneticNodeContext {
|
||||
id: Uuid::new_v4(),
|
||||
generation: 0,
|
||||
gemla_context: (),
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
state
|
||||
.simulate(GeneticNodeContext {
|
||||
id: Uuid::new_v4(),
|
||||
generation: 0,
|
||||
gemla_context: (),
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(original_population
|
||||
.iter()
|
||||
.zip(state.population.iter())
|
||||
.all(|(&a, &b)| b >= a - 3 && b <= a + 6))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mutate() {
|
||||
#[tokio::test]
|
||||
async fn test_mutate() {
|
||||
let mut state = TestState {
|
||||
population: vec![4, 3, 3],
|
||||
max_generations: 1,
|
||||
};
|
||||
|
||||
state.mutate().unwrap();
|
||||
state
|
||||
.mutate(GeneticNodeContext {
|
||||
id: Uuid::new_v4(),
|
||||
generation: 0,
|
||||
gemla_context: (),
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(state.population.len(), POPULATION_SIZE as usize);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_merge() {
|
||||
#[tokio::test]
|
||||
async fn test_merge() {
|
||||
let state1 = TestState {
|
||||
population: vec![1, 2, 4, 5],
|
||||
max_generations: 1,
|
||||
};
|
||||
|
||||
let state2 = TestState {
|
||||
population: vec![0, 1, 3, 7],
|
||||
max_generations: 1,
|
||||
};
|
||||
|
||||
let merged_state = TestState::merge(&state1, &state2).unwrap();
|
||||
let merged_state = TestState::merge(&state1, &state2, &Uuid::new_v4(), ())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(merged_state.population.len(), POPULATION_SIZE as usize);
|
||||
assert!(merged_state.population.iter().any(|&x| x == 7));
|
||||
|
|
|
@ -1,2 +0,0 @@
|
|||
/// Corresponds to the FILE command line argument used in accordance with the clap crate.
|
||||
pub const FILE: &str = "FILE";
|
|
@ -1 +0,0 @@
|
|||
pub mod args;
|
|
@ -5,7 +5,9 @@
|
|||
use crate::error::Error;
|
||||
|
||||
use anyhow::Context;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use async_trait::async_trait;
|
||||
use log::info;
|
||||
use serde::{de::DeserializeOwned, Deserialize, Serialize};
|
||||
use std::fmt::Debug;
|
||||
use uuid::Uuid;
|
||||
|
||||
|
@ -24,45 +26,65 @@ pub enum GeneticState {
|
|||
Finish,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct GeneticNodeContext<S> {
|
||||
pub generation: u64,
|
||||
pub id: Uuid,
|
||||
pub gemla_context: S,
|
||||
}
|
||||
|
||||
/// A trait used to interact with the internal state of nodes within the [`Bracket`]
|
||||
///
|
||||
/// [`Bracket`]: crate::bracket::Bracket
|
||||
pub trait GeneticNode {
|
||||
#[async_trait]
|
||||
pub trait GeneticNode: Send {
|
||||
type Context;
|
||||
|
||||
/// Initializes a new instance of a [`GeneticState`].
|
||||
///
|
||||
/// # Examples
|
||||
/// TODO
|
||||
fn initialize() -> Result<Box<Self>, Error>;
|
||||
async fn initialize(context: GeneticNodeContext<Self::Context>) -> Result<Box<Self>, Error>;
|
||||
|
||||
fn simulate(&mut self) -> Result<(), Error>;
|
||||
async fn simulate(&mut self, context: GeneticNodeContext<Self::Context>)
|
||||
-> Result<bool, Error>;
|
||||
|
||||
/// Mutates members in a population and/or crossbreeds them to produce new offspring.
|
||||
///
|
||||
/// # Examples
|
||||
/// TODO
|
||||
fn mutate(&mut self) -> Result<(), Error>;
|
||||
async fn mutate(&mut self, context: GeneticNodeContext<Self::Context>) -> Result<(), Error>;
|
||||
|
||||
fn merge(left: &Self, right: &Self) -> Result<Box<Self>, Error>;
|
||||
async fn merge(
|
||||
left: &Self,
|
||||
right: &Self,
|
||||
id: &Uuid,
|
||||
context: Self::Context,
|
||||
) -> Result<Box<Self>, Error>;
|
||||
}
|
||||
|
||||
/// Used externally to wrap a node implementing the [`GeneticNode`] trait. Processes state transitions for the given node as
|
||||
/// well as signal recovery. Transition states are given by [`GeneticState`]
|
||||
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
|
||||
pub struct GeneticNodeWrapper<T> {
|
||||
pub struct GeneticNodeWrapper<T>
|
||||
where
|
||||
T: Clone,
|
||||
{
|
||||
node: Option<T>,
|
||||
state: GeneticState,
|
||||
generation: u64,
|
||||
max_generations: u64,
|
||||
id: Uuid,
|
||||
}
|
||||
|
||||
impl<T> Default for GeneticNodeWrapper<T> {
|
||||
impl<T> Default for GeneticNodeWrapper<T>
|
||||
where
|
||||
T: Clone,
|
||||
{
|
||||
fn default() -> Self {
|
||||
GeneticNodeWrapper {
|
||||
node: None,
|
||||
state: GeneticState::Initialize,
|
||||
generation: 1,
|
||||
max_generations: 1,
|
||||
id: Uuid::new_v4(),
|
||||
}
|
||||
}
|
||||
|
@ -70,21 +92,20 @@ impl<T> Default for GeneticNodeWrapper<T> {
|
|||
|
||||
impl<T> GeneticNodeWrapper<T>
|
||||
where
|
||||
T: GeneticNode + Debug,
|
||||
T: GeneticNode + Debug + Send + Clone,
|
||||
T::Context: Send + Sync + Clone + Debug + Serialize + DeserializeOwned + 'static + Default,
|
||||
{
|
||||
pub fn new(max_generations: u64) -> Self {
|
||||
pub fn new() -> Self {
|
||||
GeneticNodeWrapper::<T> {
|
||||
max_generations,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from(data: T, max_generations: u64, id: Uuid) -> Self {
|
||||
pub fn from(data: T, id: Uuid) -> Self {
|
||||
GeneticNodeWrapper {
|
||||
node: Some(data),
|
||||
state: GeneticState::Simulate,
|
||||
generation: 1,
|
||||
max_generations,
|
||||
id,
|
||||
}
|
||||
}
|
||||
|
@ -93,36 +114,51 @@ where
|
|||
self.node.as_ref()
|
||||
}
|
||||
|
||||
pub fn take(&mut self) -> Option<T> {
|
||||
self.node.take()
|
||||
}
|
||||
|
||||
pub fn id(&self) -> Uuid {
|
||||
self.id
|
||||
}
|
||||
|
||||
pub fn max_generations(&self) -> u64 {
|
||||
self.max_generations
|
||||
pub fn generation(&self) -> u64 {
|
||||
self.generation
|
||||
}
|
||||
|
||||
pub fn state(&self) -> GeneticState {
|
||||
self.state
|
||||
}
|
||||
|
||||
pub fn process_node(&mut self) -> Result<GeneticState, Error> {
|
||||
pub async fn process_node(&mut self, gemla_context: T::Context) -> Result<GeneticState, Error> {
|
||||
let context = GeneticNodeContext {
|
||||
generation: self.generation,
|
||||
id: self.id,
|
||||
gemla_context,
|
||||
};
|
||||
|
||||
match (self.state, &mut self.node) {
|
||||
(GeneticState::Initialize, _) => {
|
||||
self.node = Some(*T::initialize()?);
|
||||
self.node = Some(*T::initialize(context.clone()).await?);
|
||||
self.state = GeneticState::Simulate;
|
||||
}
|
||||
(GeneticState::Simulate, Some(n)) => {
|
||||
n.simulate()
|
||||
let next_generation = n
|
||||
.simulate(context.clone())
|
||||
.await
|
||||
.with_context(|| format!("Error simulating node: {:?}", self))?;
|
||||
|
||||
self.state = if self.generation >= self.max_generations {
|
||||
GeneticState::Finish
|
||||
} else {
|
||||
info!("Simulation complete and continuing: {:?}", next_generation);
|
||||
|
||||
self.state = if next_generation {
|
||||
GeneticState::Mutate
|
||||
} else {
|
||||
GeneticState::Finish
|
||||
};
|
||||
}
|
||||
(GeneticState::Mutate, Some(n)) => {
|
||||
n.mutate()
|
||||
n.mutate(context.clone())
|
||||
.await
|
||||
.with_context(|| format!("Error mutating node: {:?}", self))?;
|
||||
|
||||
self.generation += 1;
|
||||
|
@ -141,40 +177,64 @@ mod tests {
|
|||
use super::*;
|
||||
use crate::error::Error;
|
||||
use anyhow::anyhow;
|
||||
use async_trait::async_trait;
|
||||
|
||||
#[derive(Deserialize, Serialize, Clone, Debug, PartialEq)]
|
||||
struct TestState {
|
||||
pub score: f64,
|
||||
pub max_generations: u64,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl GeneticNode for TestState {
|
||||
fn simulate(&mut self) -> Result<(), Error> {
|
||||
type Context = ();
|
||||
|
||||
async fn simulate(
|
||||
&mut self,
|
||||
context: GeneticNodeContext<Self::Context>,
|
||||
) -> Result<bool, Error> {
|
||||
self.score += 1.0;
|
||||
if context.generation >= self.max_generations {
|
||||
Ok(false)
|
||||
} else {
|
||||
Ok(true)
|
||||
}
|
||||
}
|
||||
|
||||
async fn mutate(
|
||||
&mut self,
|
||||
_context: GeneticNodeContext<Self::Context>,
|
||||
) -> Result<(), Error> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn mutate(&mut self) -> Result<(), Error> {
|
||||
Ok(())
|
||||
async fn initialize(
|
||||
_context: GeneticNodeContext<Self::Context>,
|
||||
) -> Result<Box<TestState>, Error> {
|
||||
Ok(Box::new(TestState {
|
||||
score: 0.0,
|
||||
max_generations: 2,
|
||||
}))
|
||||
}
|
||||
|
||||
fn initialize() -> Result<Box<TestState>, Error> {
|
||||
Ok(Box::new(TestState { score: 0.0 }))
|
||||
}
|
||||
|
||||
fn merge(_l: &TestState, _r: &TestState) -> Result<Box<TestState>, Error> {
|
||||
async fn merge(
|
||||
_l: &TestState,
|
||||
_r: &TestState,
|
||||
_id: &Uuid,
|
||||
_: Self::Context,
|
||||
) -> Result<Box<TestState>, Error> {
|
||||
Err(Error::Other(anyhow!("Unable to merge")))
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_new() -> Result<(), Error> {
|
||||
let genetic_node = GeneticNodeWrapper::<TestState>::new(10);
|
||||
let genetic_node = GeneticNodeWrapper::<TestState>::new();
|
||||
|
||||
let other_genetic_node = GeneticNodeWrapper::<TestState> {
|
||||
node: None,
|
||||
state: GeneticState::Initialize,
|
||||
generation: 1,
|
||||
max_generations: 10,
|
||||
id: genetic_node.id(),
|
||||
};
|
||||
|
||||
|
@ -185,15 +245,17 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn test_from() -> Result<(), Error> {
|
||||
let val = TestState { score: 0.0 };
|
||||
let val = TestState {
|
||||
score: 0.0,
|
||||
max_generations: 10,
|
||||
};
|
||||
let uuid = Uuid::new_v4();
|
||||
let genetic_node = GeneticNodeWrapper::from(val.clone(), 10, uuid);
|
||||
let genetic_node = GeneticNodeWrapper::from(val.clone(), uuid);
|
||||
|
||||
let other_genetic_node = GeneticNodeWrapper::<TestState> {
|
||||
node: Some(val),
|
||||
state: GeneticState::Simulate,
|
||||
generation: 1,
|
||||
max_generations: 10,
|
||||
id: genetic_node.id(),
|
||||
};
|
||||
|
||||
|
@ -204,9 +266,12 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn test_as_ref() -> Result<(), Error> {
|
||||
let val = TestState { score: 3.0 };
|
||||
let val = TestState {
|
||||
score: 3.0,
|
||||
max_generations: 10,
|
||||
};
|
||||
let uuid = Uuid::new_v4();
|
||||
let genetic_node = GeneticNodeWrapper::from(val.clone(), 10, uuid);
|
||||
let genetic_node = GeneticNodeWrapper::from(val.clone(), uuid);
|
||||
|
||||
let ref_value = genetic_node.as_ref().unwrap();
|
||||
|
||||
|
@ -217,9 +282,12 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn test_id() -> Result<(), Error> {
|
||||
let val = TestState { score: 3.0 };
|
||||
let val = TestState {
|
||||
score: 3.0,
|
||||
max_generations: 10,
|
||||
};
|
||||
let uuid = Uuid::new_v4();
|
||||
let genetic_node = GeneticNodeWrapper::from(val.clone(), 10, uuid);
|
||||
let genetic_node = GeneticNodeWrapper::from(val.clone(), uuid);
|
||||
|
||||
let id_value = genetic_node.id();
|
||||
|
||||
|
@ -228,24 +296,14 @@ mod tests {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_max_generations() -> Result<(), Error> {
|
||||
let val = TestState { score: 3.0 };
|
||||
let uuid = Uuid::new_v4();
|
||||
let genetic_node = GeneticNodeWrapper::from(val.clone(), 10, uuid);
|
||||
|
||||
let max_generations = genetic_node.max_generations();
|
||||
|
||||
assert_eq!(max_generations, 10);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_state() -> Result<(), Error> {
|
||||
let val = TestState { score: 3.0 };
|
||||
let val = TestState {
|
||||
score: 3.0,
|
||||
max_generations: 10,
|
||||
};
|
||||
let uuid = Uuid::new_v4();
|
||||
let genetic_node = GeneticNodeWrapper::from(val.clone(), 10, uuid);
|
||||
let genetic_node = GeneticNodeWrapper::from(val.clone(), uuid);
|
||||
|
||||
let state = genetic_node.state();
|
||||
|
||||
|
@ -254,16 +312,16 @@ mod tests {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_process_node() -> Result<(), Error> {
|
||||
let mut genetic_node = GeneticNodeWrapper::<TestState>::new(2);
|
||||
#[tokio::test]
|
||||
async fn test_process_node() -> Result<(), Error> {
|
||||
let mut genetic_node = GeneticNodeWrapper::<TestState>::new();
|
||||
|
||||
assert_eq!(genetic_node.state(), GeneticState::Initialize);
|
||||
assert_eq!(genetic_node.process_node()?, GeneticState::Simulate);
|
||||
assert_eq!(genetic_node.process_node()?, GeneticState::Mutate);
|
||||
assert_eq!(genetic_node.process_node()?, GeneticState::Simulate);
|
||||
assert_eq!(genetic_node.process_node()?, GeneticState::Finish);
|
||||
assert_eq!(genetic_node.process_node()?, GeneticState::Finish);
|
||||
assert_eq!(genetic_node.process_node(()).await?, GeneticState::Simulate);
|
||||
assert_eq!(genetic_node.process_node(()).await?, GeneticState::Mutate);
|
||||
assert_eq!(genetic_node.process_node(()).await?, GeneticState::Simulate);
|
||||
assert_eq!(genetic_node.process_node(()).await?, GeneticState::Finish);
|
||||
assert_eq!(genetic_node.process_node(()).await?, GeneticState::Finish);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
|
|
@ -4,42 +4,44 @@
|
|||
pub mod genetic_node;
|
||||
|
||||
use crate::{error::Error, tree::Tree};
|
||||
use file_linked::FileLinked;
|
||||
use futures::{future, future::BoxFuture};
|
||||
use async_recursion::async_recursion;
|
||||
use file_linked::{constants::data_format::DataFormat, FileLinked};
|
||||
use futures::future;
|
||||
use genetic_node::{GeneticNode, GeneticNodeWrapper, GeneticState};
|
||||
use log::{info, trace, warn};
|
||||
use serde::{de::DeserializeOwned, Deserialize, Serialize};
|
||||
use std::{
|
||||
collections::HashMap, fmt::Debug, fs::File, io::ErrorKind, marker::Send, mem, path::Path,
|
||||
time::Instant,
|
||||
sync::Arc, time::Instant,
|
||||
};
|
||||
use tokio::{sync::RwLock, task::JoinHandle};
|
||||
use uuid::Uuid;
|
||||
|
||||
type SimulationTree<T> = Box<Tree<GeneticNodeWrapper<T>>>;
|
||||
|
||||
/// Provides configuration options for managing a [`Gemla`] object as it executes.
|
||||
///
|
||||
///
|
||||
/// # Examples
|
||||
/// ```
|
||||
/// ```rust,ignore
|
||||
/// #[derive(Deserialize, Serialize, Clone, Debug, PartialEq)]
|
||||
/// struct TestState {
|
||||
/// pub score: f64,
|
||||
/// }
|
||||
///
|
||||
///
|
||||
/// impl genetic_node::GeneticNode for TestState {
|
||||
/// fn simulate(&mut self) -> Result<(), Error> {
|
||||
/// self.score += 1.0;
|
||||
/// Ok(())
|
||||
/// }
|
||||
///
|
||||
///
|
||||
/// fn mutate(&mut self) -> Result<(), Error> {
|
||||
/// Ok(())
|
||||
/// }
|
||||
///
|
||||
///
|
||||
/// fn initialize() -> Result<Box<TestState>, Error> {
|
||||
/// Ok(Box::new(TestState { score: 0.0 }))
|
||||
/// }
|
||||
///
|
||||
///
|
||||
/// fn merge(left: &TestState, right: &TestState) -> Result<Box<TestState>, Error> {
|
||||
/// Ok(Box::new(if left.score > right.score {
|
||||
/// left.clone()
|
||||
|
@ -48,14 +50,13 @@ type SimulationTree<T> = Box<Tree<GeneticNodeWrapper<T>>>;
|
|||
/// }))
|
||||
/// }
|
||||
/// }
|
||||
///
|
||||
///
|
||||
/// fn main() {
|
||||
///
|
||||
/// }
|
||||
/// ```
|
||||
#[derive(Serialize, Deserialize, Copy, Clone)]
|
||||
pub struct GemlaConfig {
|
||||
pub generations_per_node: u64,
|
||||
pub overwrite: bool,
|
||||
}
|
||||
|
||||
|
@ -65,79 +66,125 @@ pub struct GemlaConfig {
|
|||
/// individuals.
|
||||
///
|
||||
/// [`GeneticNode`]: genetic_node::GeneticNode
|
||||
pub struct Gemla<'a, T>
|
||||
pub struct Gemla<T>
|
||||
where
|
||||
T: Serialize + Clone,
|
||||
T: GeneticNode + Serialize + DeserializeOwned + Debug + Send + Clone,
|
||||
T::Context: Send + Sync + Clone + Debug + Serialize + DeserializeOwned + 'static + Default,
|
||||
{
|
||||
pub data: FileLinked<(Option<SimulationTree<T>>, GemlaConfig)>,
|
||||
threads: HashMap<Uuid, BoxFuture<'a, Result<GeneticNodeWrapper<T>, Error>>>,
|
||||
pub data: FileLinked<(Option<SimulationTree<T>>, GemlaConfig, T::Context)>,
|
||||
threads: HashMap<Uuid, JoinHandle<Result<GeneticNodeWrapper<T>, Error>>>,
|
||||
}
|
||||
|
||||
impl<'a, T: 'a> Gemla<'a, T>
|
||||
impl<T: 'static> Gemla<T>
|
||||
where
|
||||
T: GeneticNode + Serialize + DeserializeOwned + Debug + Clone + Send,
|
||||
T: GeneticNode + Serialize + DeserializeOwned + Debug + Send + Sync + Clone,
|
||||
T::Context: Send + Sync + Clone + Debug + Serialize + DeserializeOwned + 'static + Default,
|
||||
{
|
||||
pub fn new(path: &Path, config: GemlaConfig) -> Result<Self, Error> {
|
||||
pub async fn new(
|
||||
path: &Path,
|
||||
config: GemlaConfig,
|
||||
data_format: DataFormat,
|
||||
) -> Result<Self, Error> {
|
||||
match File::open(path) {
|
||||
// If the file exists we either want to overwrite the file or read from the file
|
||||
// If the file exists we either want to overwrite the file or read from the file
|
||||
// based on the configuration provided
|
||||
Ok(_) => Ok(Gemla {
|
||||
data: if config.overwrite {
|
||||
FileLinked::new((None, config), path)?
|
||||
FileLinked::new((None, config, T::Context::default()), path, data_format)
|
||||
.await?
|
||||
} else {
|
||||
FileLinked::from_file(path)?
|
||||
FileLinked::from_file(path, data_format)?
|
||||
},
|
||||
threads: HashMap::new(),
|
||||
}),
|
||||
// If the file doesn't exist we must create it
|
||||
Err(error) if error.kind() == ErrorKind::NotFound => Ok(Gemla {
|
||||
data: FileLinked::new((None, config), path)?,
|
||||
data: FileLinked::new((None, config, T::Context::default()), path, data_format)
|
||||
.await?,
|
||||
threads: HashMap::new(),
|
||||
}),
|
||||
Err(error) => Err(Error::IO(error)),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn tree_ref(&self) -> Option<&SimulationTree<T>> {
|
||||
self.data.readonly().0.as_ref()
|
||||
pub fn tree_ref(&self) -> Arc<RwLock<(Option<SimulationTree<T>>, GemlaConfig, T::Context)>> {
|
||||
self.data.readonly().clone()
|
||||
}
|
||||
|
||||
pub async fn simulate(&mut self, steps: u64) -> Result<(), Error> {
|
||||
// Before we can process nodes we must create blank nodes in their place to keep track of which nodes have been processed
|
||||
// in the tree and which nodes have not.
|
||||
self.data.mutate(|(d, c)| {
|
||||
let mut tree: Option<SimulationTree<T>> = Gemla::increase_height(d.take(), c, steps);
|
||||
mem::swap(d, &mut tree);
|
||||
})?;
|
||||
let tree_completed = {
|
||||
// Only increase height if the tree is uninitialized or completed
|
||||
let data_arc = self.data.readonly();
|
||||
let data_ref = data_arc.read().await;
|
||||
let tree_ref = data_ref.0.as_ref();
|
||||
|
||||
info!(
|
||||
"Height of simulation tree increased to {}",
|
||||
self.tree_ref()
|
||||
.map(|t| format!("{}", t.height()))
|
||||
.unwrap_or_else(|| "Tree is not defined".to_string())
|
||||
);
|
||||
tree_ref.is_none() || tree_ref.map(|t| Gemla::is_completed(t)).unwrap_or(true)
|
||||
};
|
||||
|
||||
if tree_completed {
|
||||
// Before we can process nodes we must create blank nodes in their place to keep track of which nodes have been processed
|
||||
// in the tree and which nodes have not.
|
||||
self.data
|
||||
.mutate(|(d, _, _)| {
|
||||
let mut tree: Option<SimulationTree<T>> =
|
||||
Gemla::increase_height(d.take(), steps);
|
||||
mem::swap(d, &mut tree);
|
||||
})
|
||||
.await?;
|
||||
}
|
||||
|
||||
{
|
||||
// Only increase height if the tree is uninitialized or completed
|
||||
let data_arc = self.data.readonly();
|
||||
let data_ref = data_arc.read().await;
|
||||
let tree_ref = data_ref.0.as_ref();
|
||||
|
||||
info!(
|
||||
"Height of simulation tree increased to {}",
|
||||
tree_ref
|
||||
.map(|t| format!("{}", t.height()))
|
||||
.unwrap_or_else(|| "Tree is not defined".to_string())
|
||||
);
|
||||
}
|
||||
|
||||
loop {
|
||||
// We need to keep simulating until the tree has been completely processed.
|
||||
if self
|
||||
.tree_ref()
|
||||
.map(|t| Gemla::is_completed(t))
|
||||
.unwrap_or(false)
|
||||
let is_tree_processed;
|
||||
|
||||
{
|
||||
let data_arc = self.data.readonly();
|
||||
let data_ref = data_arc.read().await;
|
||||
let tree_ref = data_ref.0.as_ref();
|
||||
|
||||
is_tree_processed = tree_ref.map(|t| Gemla::is_completed(t)).unwrap_or(false)
|
||||
}
|
||||
|
||||
// We need to keep simulating until the tree has been completely processed.
|
||||
if is_tree_processed {
|
||||
self.join_threads().await?;
|
||||
|
||||
info!("Processed tree");
|
||||
break;
|
||||
}
|
||||
|
||||
if let Some(node) = self
|
||||
.tree_ref()
|
||||
.and_then(|t| self.get_unprocessed_node(t))
|
||||
{
|
||||
let (node, gemla_context) = {
|
||||
let data_arc = self.data.readonly();
|
||||
let data_ref = data_arc.read().await;
|
||||
let (tree_ref, _, gemla_context) = &*data_ref; // (Option<Box<Tree<GeneticNodeWrapper<T>>>, GemlaConfig, T::Context)
|
||||
|
||||
let node = tree_ref.as_ref().and_then(|t| self.get_unprocessed_node(t));
|
||||
|
||||
(node, gemla_context.clone())
|
||||
};
|
||||
|
||||
if let Some(node) = node {
|
||||
trace!("Adding node to process list {}", node.id());
|
||||
|
||||
self.threads
|
||||
.insert(node.id(), Box::pin(Gemla::process_node(node)));
|
||||
let gemla_context = gemla_context.clone();
|
||||
|
||||
self.threads.insert(
|
||||
node.id(),
|
||||
tokio::spawn(async move { Gemla::process_node(node, gemla_context).await }),
|
||||
);
|
||||
} else {
|
||||
trace!("No node found to process, joining threads");
|
||||
|
||||
|
@ -153,38 +200,56 @@ where
|
|||
trace!("Joining threads for nodes {:?}", self.threads.keys());
|
||||
|
||||
let results = future::join_all(self.threads.values_mut()).await;
|
||||
|
||||
// Converting a list of results into a result wrapping the list
|
||||
let reduced_results: Result<Vec<GeneticNodeWrapper<T>>, Error> =
|
||||
results.into_iter().collect();
|
||||
results.into_iter().flatten().collect();
|
||||
self.threads.clear();
|
||||
|
||||
// We need to retrieve the processed nodes from the resulting list and replace them in the original list
|
||||
reduced_results.and_then(|r| {
|
||||
self.data.mutate(|(d, _)| {
|
||||
if let Some(t) = d {
|
||||
let failed_nodes = Gemla::replace_nodes(t, r);
|
||||
// We receive a list of nodes that were unable to be found in the original tree
|
||||
if !failed_nodes.is_empty() {
|
||||
warn!(
|
||||
"Unable to find {:?} to replace in tree",
|
||||
failed_nodes.iter().map(|n| n.id())
|
||||
)
|
||||
}
|
||||
match reduced_results {
|
||||
Ok(r) => {
|
||||
self.data
|
||||
.mutate_async(|d| async move {
|
||||
// Scope to limit the duration of the read lock
|
||||
let (_, context) = {
|
||||
let data_read = d.read().await;
|
||||
(data_read.1, data_read.2.clone())
|
||||
}; // Read lock is dropped here
|
||||
|
||||
// Once the nodes are replaced we need to find nodes that can be merged from the completed children nodes
|
||||
Gemla::merge_completed_nodes(t)
|
||||
} else {
|
||||
warn!("Unable to replce nodes {:?} in empty tree", r);
|
||||
Ok(())
|
||||
}
|
||||
})?
|
||||
})?;
|
||||
let mut data_write = d.write().await;
|
||||
|
||||
if let Some(t) = data_write.0.as_mut() {
|
||||
let failed_nodes = Gemla::replace_nodes(t, r);
|
||||
// We receive a list of nodes that were unable to be found in the original tree
|
||||
if !failed_nodes.is_empty() {
|
||||
warn!(
|
||||
"Unable to find {:?} to replace in tree",
|
||||
failed_nodes.iter().map(|n| n.id())
|
||||
)
|
||||
}
|
||||
|
||||
// Once the nodes are replaced we need to find nodes that can be merged from the completed children nodes
|
||||
Gemla::merge_completed_nodes(t, context.clone()).await
|
||||
} else {
|
||||
warn!("Unable to replce nodes {:?} in empty tree", r);
|
||||
Ok(())
|
||||
}
|
||||
})
|
||||
.await??;
|
||||
}
|
||||
Err(e) => return Err(e),
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn merge_completed_nodes(tree: &mut SimulationTree<T>) -> Result<(), Error> {
|
||||
#[async_recursion]
|
||||
async fn merge_completed_nodes<'a>(
|
||||
tree: &'a mut SimulationTree<T>,
|
||||
gemla_context: T::Context,
|
||||
) -> Result<(), Error> {
|
||||
if tree.val.state() == GeneticState::Initialize {
|
||||
match (&mut tree.left, &mut tree.right) {
|
||||
// If the current node has been initialized, and has children nodes that are completed, then we need
|
||||
|
@ -195,43 +260,37 @@ where
|
|||
{
|
||||
info!("Merging nodes {} and {}", l.val.id(), r.val.id());
|
||||
if let (Some(left_node), Some(right_node)) = (l.val.as_ref(), r.val.as_ref()) {
|
||||
let merged_node = GeneticNode::merge(left_node, right_node)?;
|
||||
tree.val = GeneticNodeWrapper::from(
|
||||
*merged_node,
|
||||
tree.val.max_generations(),
|
||||
tree.val.id(),
|
||||
);
|
||||
let merged_node = GeneticNode::merge(
|
||||
left_node,
|
||||
right_node,
|
||||
&tree.val.id(),
|
||||
gemla_context.clone(),
|
||||
)
|
||||
.await?;
|
||||
tree.val = GeneticNodeWrapper::from(*merged_node, tree.val.id());
|
||||
}
|
||||
}
|
||||
(Some(l), Some(r)) => {
|
||||
Gemla::merge_completed_nodes(l)?;
|
||||
Gemla::merge_completed_nodes(r)?;
|
||||
Gemla::merge_completed_nodes(l, gemla_context.clone()).await?;
|
||||
Gemla::merge_completed_nodes(r, gemla_context.clone()).await?;
|
||||
}
|
||||
// If there is only one child node that's completed then we want to copy it to the parent node
|
||||
(Some(l), None) if l.val.state() == GeneticState::Finish => {
|
||||
trace!("Copying node {}", l.val.id());
|
||||
|
||||
if let Some(left_node) = l.val.as_ref() {
|
||||
GeneticNodeWrapper::from(
|
||||
left_node.clone(),
|
||||
tree.val.max_generations(),
|
||||
tree.val.id(),
|
||||
);
|
||||
GeneticNodeWrapper::from(left_node.clone(), tree.val.id());
|
||||
}
|
||||
}
|
||||
(Some(l), None) => Gemla::merge_completed_nodes(l)?,
|
||||
(Some(l), None) => Gemla::merge_completed_nodes(l, gemla_context.clone()).await?,
|
||||
(None, Some(r)) if r.val.state() == GeneticState::Finish => {
|
||||
trace!("Copying node {}", r.val.id());
|
||||
|
||||
if let Some(right_node) = r.val.as_ref() {
|
||||
tree.val = GeneticNodeWrapper::from(
|
||||
right_node.clone(),
|
||||
tree.val.max_generations(),
|
||||
tree.val.id(),
|
||||
);
|
||||
tree.val = GeneticNodeWrapper::from(right_node.clone(), tree.val.id());
|
||||
}
|
||||
}
|
||||
(None, Some(r)) => Gemla::merge_completed_nodes(r)?,
|
||||
(None, Some(r)) => Gemla::merge_completed_nodes(r, gemla_context.clone()).await?,
|
||||
(_, _) => (),
|
||||
}
|
||||
}
|
||||
|
@ -240,15 +299,18 @@ where
|
|||
}
|
||||
|
||||
fn get_unprocessed_node(&self, tree: &SimulationTree<T>) -> Option<GeneticNodeWrapper<T>> {
|
||||
// If the current node has been processed or exists in the thread list then we want to stop recursing. Checking if it exists in the thread list
|
||||
// If the current node has been processed or exists in the thread list then we want to stop recursing. Checking if it exists in the thread list
|
||||
// should be fine because we process the tree from bottom to top.
|
||||
if tree.val.state() != GeneticState::Finish && !self.threads.contains_key(&tree.val.id()) {
|
||||
match (&tree.left, &tree.right) {
|
||||
// If the children are finished we can start processing the currrent node. The current node should be merged from the children already
|
||||
// If the children are finished we can start processing the currrent node. The current node should be merged from the children already
|
||||
// during join_threads.
|
||||
(Some(l), Some(r))
|
||||
if l.val.state() == GeneticState::Finish
|
||||
&& r.val.state() == GeneticState::Finish => Some(tree.val.clone()),
|
||||
&& r.val.state() == GeneticState::Finish =>
|
||||
{
|
||||
Some(tree.val.clone())
|
||||
}
|
||||
(Some(l), Some(r)) => self
|
||||
.get_unprocessed_node(l)
|
||||
.or_else(|| self.get_unprocessed_node(r)),
|
||||
|
@ -278,25 +340,19 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
fn increase_height(
|
||||
tree: Option<SimulationTree<T>>,
|
||||
config: &GemlaConfig,
|
||||
amount: u64,
|
||||
) -> Option<SimulationTree<T>> {
|
||||
fn increase_height(tree: Option<SimulationTree<T>>, amount: u64) -> Option<SimulationTree<T>> {
|
||||
if amount == 0 {
|
||||
tree
|
||||
} else {
|
||||
let left_branch_right =
|
||||
let left_branch_height =
|
||||
tree.as_ref().map(|t| t.height() as u64).unwrap_or(0) + amount - 1;
|
||||
|
||||
|
||||
Some(Box::new(Tree::new(
|
||||
GeneticNodeWrapper::new(config.generations_per_node),
|
||||
Gemla::increase_height(tree, config, amount - 1),
|
||||
GeneticNodeWrapper::new(),
|
||||
Gemla::increase_height(tree, amount - 1),
|
||||
// The right branch height has to equal the left branches total height
|
||||
if left_branch_right > 0 {
|
||||
Some(Box::new(btree!(GeneticNodeWrapper::new(
|
||||
left_branch_right * config.generations_per_node
|
||||
))))
|
||||
if left_branch_height > 0 {
|
||||
Some(Box::new(btree!(GeneticNodeWrapper::new())))
|
||||
} else {
|
||||
None
|
||||
},
|
||||
|
@ -306,16 +362,19 @@ where
|
|||
|
||||
fn is_completed(tree: &SimulationTree<T>) -> bool {
|
||||
// If the current node is finished, then by convention the children should all be finished as well
|
||||
tree.val.state() == GeneticState::Finish
|
||||
tree.val.state() == GeneticState::Finish
|
||||
}
|
||||
|
||||
async fn process_node(mut node: GeneticNodeWrapper<T>) -> Result<GeneticNodeWrapper<T>, Error> {
|
||||
async fn process_node(
|
||||
mut node: GeneticNodeWrapper<T>,
|
||||
gemla_context: T::Context,
|
||||
) -> Result<GeneticNodeWrapper<T>, Error> {
|
||||
let node_state_time = Instant::now();
|
||||
let node_state = node.state();
|
||||
|
||||
node.process_node()?;
|
||||
node.process_node(gemla_context.clone()).await?;
|
||||
|
||||
trace!(
|
||||
info!(
|
||||
"{:?} completed in {:?} for {}",
|
||||
node_state,
|
||||
node_state_time.elapsed(),
|
||||
|
@ -333,9 +392,13 @@ where
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::core::*;
|
||||
use async_trait::async_trait;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::path::PathBuf;
|
||||
use std::fs;
|
||||
use std::path::PathBuf;
|
||||
use tokio::runtime::Runtime;
|
||||
|
||||
use self::genetic_node::GeneticNodeContext;
|
||||
|
||||
struct CleanUp {
|
||||
path: PathBuf,
|
||||
|
@ -364,23 +427,43 @@ mod tests {
|
|||
#[derive(Deserialize, Serialize, Clone, Debug, PartialEq)]
|
||||
struct TestState {
|
||||
pub score: f64,
|
||||
pub max_generations: u64,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl genetic_node::GeneticNode for TestState {
|
||||
fn simulate(&mut self) -> Result<(), Error> {
|
||||
type Context = ();
|
||||
|
||||
async fn simulate(
|
||||
&mut self,
|
||||
context: GeneticNodeContext<Self::Context>,
|
||||
) -> Result<bool, Error> {
|
||||
self.score += 1.0;
|
||||
Ok(context.generation < self.max_generations)
|
||||
}
|
||||
|
||||
async fn mutate(
|
||||
&mut self,
|
||||
_context: GeneticNodeContext<Self::Context>,
|
||||
) -> Result<(), Error> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn mutate(&mut self) -> Result<(), Error> {
|
||||
Ok(())
|
||||
async fn initialize(
|
||||
_context: GeneticNodeContext<Self::Context>,
|
||||
) -> Result<Box<TestState>, Error> {
|
||||
Ok(Box::new(TestState {
|
||||
score: 0.0,
|
||||
max_generations: 10,
|
||||
}))
|
||||
}
|
||||
|
||||
fn initialize() -> Result<Box<TestState>, Error> {
|
||||
Ok(Box::new(TestState { score: 0.0 }))
|
||||
}
|
||||
|
||||
fn merge(left: &TestState, right: &TestState) -> Result<Box<TestState>, Error> {
|
||||
async fn merge(
|
||||
left: &TestState,
|
||||
right: &TestState,
|
||||
_id: &Uuid,
|
||||
_: Self::Context,
|
||||
) -> Result<Box<TestState>, Error> {
|
||||
Ok(Box::new(if left.score > right.score {
|
||||
left.clone()
|
||||
} else {
|
||||
|
@ -389,66 +472,93 @@ mod tests {
|
|||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_new() -> Result<(), Error> {
|
||||
#[tokio::test]
|
||||
async fn test_new() -> Result<(), Error> {
|
||||
let path = PathBuf::from("test_new_non_existing");
|
||||
CleanUp::new(&path).run(|p| {
|
||||
assert!(!path.exists());
|
||||
// Use `spawn_blocking` to run synchronous code that needs to call async code internally.
|
||||
tokio::task::spawn_blocking(move || {
|
||||
let rt = Runtime::new().unwrap(); // Create a new Tokio runtime for the async block.
|
||||
CleanUp::new(&path).run(move |p| {
|
||||
rt.block_on(async {
|
||||
assert!(!path.exists());
|
||||
|
||||
// Testing initial creation
|
||||
let mut config = GemlaConfig {
|
||||
generations_per_node: 1,
|
||||
overwrite: true
|
||||
};
|
||||
let mut gemla = Gemla::<TestState>::new(&p, config)?;
|
||||
// Testing initial creation
|
||||
let mut config = GemlaConfig { overwrite: true };
|
||||
let mut gemla = Gemla::<TestState>::new(&p, config, DataFormat::Json).await?;
|
||||
|
||||
smol::block_on(gemla.simulate(2))?;
|
||||
assert_eq!(gemla.data.readonly().0.as_ref().unwrap().height(), 2);
|
||||
|
||||
drop(gemla);
|
||||
assert!(path.exists());
|
||||
// Now we can use `.await` within the spawned blocking task.
|
||||
gemla.simulate(2).await?;
|
||||
let data = gemla.data.readonly();
|
||||
let data_lock = data.read().await;
|
||||
assert_eq!(data_lock.0.as_ref().unwrap().height(), 2);
|
||||
|
||||
// Testing overwriting data
|
||||
let mut gemla = Gemla::<TestState>::new(&p, config)?;
|
||||
drop(data_lock);
|
||||
drop(gemla);
|
||||
assert!(path.exists());
|
||||
|
||||
smol::block_on(gemla.simulate(2))?;
|
||||
assert_eq!(gemla.data.readonly().0.as_ref().unwrap().height(), 2);
|
||||
// Testing overwriting data
|
||||
let mut gemla = Gemla::<TestState>::new(&p, config, DataFormat::Json).await?;
|
||||
|
||||
drop(gemla);
|
||||
assert!(path.exists());
|
||||
gemla.simulate(2).await?;
|
||||
let data = gemla.data.readonly();
|
||||
let data_lock = data.read().await;
|
||||
assert_eq!(data_lock.0.as_ref().unwrap().height(), 2);
|
||||
|
||||
// Testing not-overwriting data
|
||||
config.overwrite = false;
|
||||
let mut gemla = Gemla::<TestState>::new(&p, config)?;
|
||||
drop(data_lock);
|
||||
drop(gemla);
|
||||
assert!(path.exists());
|
||||
|
||||
smol::block_on(gemla.simulate(2))?;
|
||||
assert_eq!(gemla.tree_ref().unwrap().height(), 4);
|
||||
// Testing not-overwriting data
|
||||
config.overwrite = false;
|
||||
let mut gemla = Gemla::<TestState>::new(&p, config, DataFormat::Json).await?;
|
||||
|
||||
drop(gemla);
|
||||
assert!(path.exists());
|
||||
gemla.simulate(2).await?;
|
||||
let data = gemla.data.readonly();
|
||||
let data_lock = data.read().await;
|
||||
let tree = data_lock.0.as_ref().unwrap();
|
||||
assert_eq!(tree.height(), 4);
|
||||
|
||||
Ok(())
|
||||
drop(data_lock);
|
||||
drop(gemla);
|
||||
assert!(path.exists());
|
||||
|
||||
Ok(())
|
||||
})
|
||||
})
|
||||
})
|
||||
.await
|
||||
.unwrap()?; // Wait for the blocking task to complete, then handle the Result.
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_simulate() -> Result<(), Error> {
|
||||
#[tokio::test]
|
||||
async fn test_simulate() -> Result<(), Error> {
|
||||
let path = PathBuf::from("test_simulate");
|
||||
CleanUp::new(&path).run(|p| {
|
||||
// Testing initial creation
|
||||
let config = GemlaConfig {
|
||||
generations_per_node: 10,
|
||||
overwrite: true
|
||||
};
|
||||
let mut gemla = Gemla::<TestState>::new(&p, config)?;
|
||||
// Use `spawn_blocking` to run the synchronous closure that internally awaits async code.
|
||||
tokio::task::spawn_blocking(move || {
|
||||
let rt = Runtime::new().unwrap(); // Create a new Tokio runtime for the async block.
|
||||
CleanUp::new(&path).run(move |p| {
|
||||
rt.block_on(async {
|
||||
// Testing initial creation
|
||||
let config = GemlaConfig { overwrite: true };
|
||||
let mut gemla = Gemla::<TestState>::new(&p, config, DataFormat::Json).await?;
|
||||
|
||||
smol::block_on(gemla.simulate(5))?;
|
||||
let tree = gemla.tree_ref().unwrap();
|
||||
assert_eq!(tree.height(), 5);
|
||||
assert_eq!(tree.val.as_ref().unwrap().score, 50.0);
|
||||
// Now we can use `.await` within the spawned blocking task.
|
||||
gemla.simulate(5).await?;
|
||||
let data = gemla.data.readonly();
|
||||
let data_lock = data.read().await;
|
||||
let tree = data_lock.0.as_ref().unwrap();
|
||||
assert_eq!(tree.height(), 5);
|
||||
assert_eq!(tree.val.as_ref().unwrap().score, 50.0);
|
||||
|
||||
Ok(())
|
||||
Ok(())
|
||||
})
|
||||
})
|
||||
})
|
||||
}
|
||||
.await
|
||||
.unwrap()?; // Wait for the blocking task to complete, then handle the Result.
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
#[macro_use]
|
||||
pub mod tree;
|
||||
pub mod constants;
|
||||
pub mod core;
|
||||
pub mod error;
|
||||
|
|
|
@ -36,7 +36,7 @@ use std::cmp::max;
|
|||
/// t.right = Some(Box::new(btree!(3)));
|
||||
/// assert_eq!(t.right.unwrap().val, 3);
|
||||
/// ```
|
||||
#[derive(Default, Serialize, Deserialize, Clone, PartialEq, Debug)]
|
||||
#[derive(Default, Serialize, Deserialize, PartialEq, Debug)]
|
||||
pub struct Tree<T> {
|
||||
pub val: T,
|
||||
pub left: Option<Box<Tree<T>>>,
|
||||
|
|
380
parameter_analysis.py
Normal file
380
parameter_analysis.py
Normal file
|
@ -0,0 +1,380 @@
|
|||
# Re-importing necessary libraries
|
||||
import json
|
||||
import matplotlib.pyplot as plt
|
||||
from collections import defaultdict
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import seaborn as sns
|
||||
import matplotlib.colors as mcolors
|
||||
import matplotlib.cm as cm
|
||||
import matplotlib.ticker as ticker
|
||||
|
||||
# Simplified JSON data for demonstration
|
||||
with open('gemla/round4.json', 'r') as file:
|
||||
simplified_json_data = json.load(file)
|
||||
|
||||
# Function to traverse the tree to find a node id
|
||||
def traverse_right_nodes(node):
|
||||
if node is None:
|
||||
return []
|
||||
|
||||
right_node = node.get("right")
|
||||
left_node = node.get("left")
|
||||
|
||||
if right_node is None and left_node is None:
|
||||
return []
|
||||
elif right_node and left_node:
|
||||
return [right_node] + traverse_right_nodes(left_node)
|
||||
|
||||
return []
|
||||
|
||||
# Getting most recent right graph
|
||||
right_nodes = traverse_right_nodes(simplified_json_data[0])
|
||||
|
||||
# Heatmaps
|
||||
# Data structure to store mutation rates, generations, and scores
|
||||
mutation_rate_data = defaultdict(lambda: defaultdict(list))
|
||||
|
||||
# Populate the dictionary with scores indexed by mutation rate and generation
|
||||
for node in right_nodes:
|
||||
node_val = node["val"]["node"]
|
||||
if node_val:
|
||||
scores = node_val["scores"]
|
||||
minor_mutation_rate = node_val["minor_mutation_rate"]
|
||||
generation = node_val["generation"]
|
||||
# Ensure each score is associated with the correct generation
|
||||
for gen_index, score_list in enumerate(scores):
|
||||
for score in score_list.values():
|
||||
mutation_rate_data[minor_mutation_rate][gen_index].append(score)
|
||||
|
||||
# Prepare data for heatmap
|
||||
max_generation = max(max(gens.keys()) for gens in mutation_rate_data.values())
|
||||
heatmap_data = np.full((len(mutation_rate_data), max_generation + 1), np.nan)
|
||||
|
||||
# Populate the heatmap data with average scores
|
||||
mutation_rates = sorted(mutation_rate_data.keys())
|
||||
for i, mutation_rate in enumerate(mutation_rates):
|
||||
for generation in range(max_generation + 1):
|
||||
scores = mutation_rate_data[mutation_rate][generation]
|
||||
if scores: # Check if there are scores for this generation
|
||||
heatmap_data[i, generation] = np.mean(scores)
|
||||
|
||||
# Creating a DataFrame for the heatmap
|
||||
df_heatmap = pd.DataFrame(
|
||||
data=heatmap_data,
|
||||
index=mutation_rates,
|
||||
columns=range(max_generation + 1)
|
||||
)
|
||||
|
||||
# Data structure to store major mutation rates, generations, and scores
|
||||
major_mutation_rate_data = defaultdict(lambda: defaultdict(list))
|
||||
|
||||
# Populate the dictionary with scores indexed by major mutation rate and generation
|
||||
# This is assuming the structure to retrieve major_mutation_rate is similar to minor_mutation_rate
|
||||
for node in right_nodes:
|
||||
node_val = node["val"]["node"]
|
||||
if node_val:
|
||||
scores = node_val["scores"]
|
||||
major_mutation_rate = node_val["major_mutation_rate"]
|
||||
generation = node_val["generation"]
|
||||
for gen_index, score_list in enumerate(scores):
|
||||
for score in score_list.values():
|
||||
major_mutation_rate_data[major_mutation_rate][gen_index].append(score)
|
||||
|
||||
# Prepare the heatmap data for major_mutation_rate similar to minor_mutation_rate
|
||||
major_heatmap_data = np.full((len(major_mutation_rate_data), max_generation + 1), np.nan)
|
||||
major_mutation_rates = sorted(major_mutation_rate_data.keys())
|
||||
|
||||
for i, major_rate in enumerate(major_mutation_rates):
|
||||
for generation in range(max_generation + 1):
|
||||
scores = major_mutation_rate_data[major_rate][generation]
|
||||
if scores: # Check if there are scores for this generation
|
||||
major_heatmap_data[i, generation] = np.mean(scores)
|
||||
|
||||
# Creating a DataFrame for the major mutation rate heatmap
|
||||
df_major_heatmap = pd.DataFrame(
|
||||
data=major_heatmap_data,
|
||||
index=major_mutation_rates,
|
||||
columns=range(max_generation + 1)
|
||||
)
|
||||
|
||||
# crossbreed_segments
|
||||
# Data structure to store major mutation rates, generations, and scores
|
||||
crossbreed_segments_data = defaultdict(lambda: defaultdict(list))
|
||||
|
||||
# Populate the dictionary with scores indexed by major mutation rate and generation
|
||||
# This is assuming the structure to retrieve major_mutation_rate is similar to minor_mutation_rate
|
||||
for node in right_nodes:
|
||||
node_val = node["val"]["node"]
|
||||
if node_val:
|
||||
scores = node_val["scores"]
|
||||
crossbreed_segments = node_val["crossbreed_segments"]
|
||||
generation = node_val["generation"]
|
||||
for gen_index, score_list in enumerate(scores):
|
||||
for score in score_list.values():
|
||||
crossbreed_segments_data[crossbreed_segments][gen_index].append(score)
|
||||
|
||||
# Prepare the heatmap data for crossbreed_segments similar to minor_mutation_rate
|
||||
crossbreed_heatmap_data = np.full((len(crossbreed_segments_data), max_generation + 1), np.nan)
|
||||
crossbreed_segments = sorted(crossbreed_segments_data.keys())
|
||||
|
||||
for i, crossbreed_segment in enumerate(crossbreed_segments):
|
||||
for generation in range(max_generation + 1):
|
||||
scores = crossbreed_segments_data[crossbreed_segment][generation]
|
||||
if scores: # Check if there are scores for this generation
|
||||
crossbreed_heatmap_data[i, generation] = np.mean(scores)
|
||||
|
||||
# Creating a DataFrame for the major mutation rate heatmap
|
||||
df_crossbreed_heatmap = pd.DataFrame(
|
||||
data=crossbreed_heatmap_data,
|
||||
index=crossbreed_segments,
|
||||
columns=range(max_generation + 1)
|
||||
)
|
||||
|
||||
# mutation_weight_range
|
||||
# Data structure to store major mutation rates, generations, and scores
|
||||
mutation_weight_range_data = defaultdict(lambda: defaultdict(list))
|
||||
|
||||
# Populate the dictionary with scores indexed by major mutation rate and generation
|
||||
# This is assuming the structure to retrieve major_mutation_rate is similar to minor_mutation_rate
|
||||
for node in right_nodes:
|
||||
node_val = node["val"]["node"]
|
||||
if node_val:
|
||||
scores = node_val["scores"]
|
||||
mutation_weight_range = node_val["mutation_weight_range"]
|
||||
positive_extent = mutation_weight_range["end"]
|
||||
negative_extent = -mutation_weight_range["start"]
|
||||
mutation_weight_range = (positive_extent + negative_extent) / 2
|
||||
generation = node_val["generation"]
|
||||
for gen_index, score_list in enumerate(scores):
|
||||
for score in score_list.values():
|
||||
mutation_weight_range_data[mutation_weight_range][gen_index].append(score)
|
||||
|
||||
# Prepare the heatmap data for crossbreed_segments similar to minor_mutation_rate
|
||||
mutation_weight_range_heatmap_data = np.full((len(mutation_weight_range_data), max_generation + 1), np.nan)
|
||||
mutation_weight_ranges = sorted(mutation_weight_range_data.keys())
|
||||
|
||||
for i, mutation_weight_range in enumerate(mutation_weight_ranges):
|
||||
for generation in range(max_generation + 1):
|
||||
scores = mutation_weight_range_data[mutation_weight_range][generation]
|
||||
if scores: # Check if there are scores for this generation
|
||||
mutation_weight_range_heatmap_data[i, generation] = np.mean(scores)
|
||||
|
||||
# Creating a DataFrame for the major mutation rate heatmap
|
||||
df_mutation_weight_range_heatmap = pd.DataFrame(
|
||||
data=mutation_weight_range_heatmap_data,
|
||||
index=mutation_weight_ranges,
|
||||
columns=range(max_generation + 1)
|
||||
)
|
||||
|
||||
# weight_initialization_range
|
||||
# Data structure to store major mutation rates, generations, and scores
|
||||
weight_initialization_range_data = defaultdict(lambda: defaultdict(list))
|
||||
|
||||
# Populate the dictionary with scores indexed by major mutation rate and generation
|
||||
# This is assuming the structure to retrieve major_mutation_rate is similar to minor_mutation_rate
|
||||
for node in right_nodes:
|
||||
node_val = node["val"]["node"]
|
||||
if node_val:
|
||||
scores = node_val["scores"]
|
||||
weight_initialization_range = node_val["weight_initialization_range"]
|
||||
positive_extent = weight_initialization_range["end"]
|
||||
negative_extent = -weight_initialization_range["start"]
|
||||
weight_initialization_range = (positive_extent + negative_extent) / 2
|
||||
generation = node_val["generation"]
|
||||
for gen_index, score_list in enumerate(scores):
|
||||
for score in score_list.values():
|
||||
weight_initialization_range_data[weight_initialization_range][gen_index].append(score)
|
||||
|
||||
# Prepare the heatmap data for crossbreed_segments similar to minor_mutation_rate
|
||||
weight_initialization_range_heatmap_data = np.full((len(weight_initialization_range_data), max_generation + 1), np.nan)
|
||||
weight_initialization_ranges = sorted(weight_initialization_range_data.keys())
|
||||
|
||||
for i, weight_initialization_range in enumerate(weight_initialization_ranges):
|
||||
for generation in range(max_generation + 1):
|
||||
scores = weight_initialization_range_data[weight_initialization_range][generation]
|
||||
if scores: # Check if there are scores for this generation
|
||||
weight_initialization_range_heatmap_data[i, generation] = np.mean(scores)
|
||||
|
||||
# Creating a DataFrame for the major mutation rate heatmap
|
||||
df_weight_initialization_range_heatmap = pd.DataFrame(
|
||||
data=weight_initialization_range_heatmap_data,
|
||||
index=weight_initialization_ranges,
|
||||
columns=range(max_generation + 1)
|
||||
)
|
||||
|
||||
# weight_initialization_range_skew
|
||||
# Data structure to store major mutation rates, generations, and scores
|
||||
weight_initialization_range_skew_data = defaultdict(lambda: defaultdict(list))
|
||||
|
||||
# Populate the dictionary with scores indexed by major mutation rate and generation
|
||||
# This is assuming the structure to retrieve major_mutation_rate is similar to minor_mutation_rate
|
||||
for node in right_nodes:
|
||||
node_val = node["val"]["node"]
|
||||
if node_val:
|
||||
scores = node_val["scores"]
|
||||
weight_initialization_range = node_val["weight_initialization_range"]
|
||||
positive_extent = weight_initialization_range["end"]
|
||||
negative_extent = -weight_initialization_range["start"]
|
||||
weight_initialization_range_skew = (positive_extent - negative_extent) / 2
|
||||
generation = node_val["generation"]
|
||||
for gen_index, score_list in enumerate(scores):
|
||||
for score in score_list.values():
|
||||
weight_initialization_range_skew_data[weight_initialization_range_skew][gen_index].append(score)
|
||||
|
||||
# Prepare the heatmap data for crossbreed_segments similar to minor_mutation_rate
|
||||
weight_initialization_range_skew_heatmap_data = np.full((len(weight_initialization_range_skew_data), max_generation + 1), np.nan)
|
||||
weight_initialization_range_skews = sorted(weight_initialization_range_skew_data.keys())
|
||||
|
||||
for i, weight_initialization_range_skew in enumerate(weight_initialization_range_skews):
|
||||
for generation in range(max_generation + 1):
|
||||
scores = weight_initialization_range_skew_data[weight_initialization_range_skew][generation]
|
||||
if scores: # Check if there are scores for this generation
|
||||
weight_initialization_range_skew_heatmap_data[i, generation] = np.mean(scores)
|
||||
|
||||
# Creating a DataFrame for the major mutation rate heatmap
|
||||
df_weight_initialization_range_skew_heatmap = pd.DataFrame(
|
||||
data=weight_initialization_range_skew_heatmap_data,
|
||||
index=weight_initialization_range_skews,
|
||||
columns=range(max_generation + 1)
|
||||
)
|
||||
|
||||
# Analyze number of neurons correlation to score
|
||||
# We can get the number of neurons via node_val["nn_shapes"] which contains an array of maps
|
||||
# Each map has a key for the individual id and a value which is an array of integers representing the number of neurons in each layer
|
||||
# We can use the individual id to get the score from the scores array
|
||||
# We then generate a density map of the number of neurons vs the score
|
||||
neuron_number_score_data = defaultdict(lambda: defaultdict(list))
|
||||
|
||||
for node in right_nodes:
|
||||
node_val = node["val"]["node"]
|
||||
if node_val:
|
||||
scores = node_val["scores"]
|
||||
nn_shapes = node_val["nn_shapes"]
|
||||
# Both scores and nn_shapes are arrays where score is 1 less in length than nn_shapes (each index corresponds to a generation)
|
||||
for gen_index, score in enumerate(scores):
|
||||
for individual_id, nn_shape in nn_shapes[gen_index].items():
|
||||
neuron_number = sum(nn_shape)
|
||||
# check if score has a value for the individual id
|
||||
if individual_id not in score:
|
||||
continue
|
||||
neuron_number_score_data[neuron_number][gen_index].append(score[individual_id])
|
||||
|
||||
# prepare the density map data
|
||||
neuron_number_score_heatmap_data = np.full((len(neuron_number_score_data), max_generation + 1), np.nan)
|
||||
neuron_numbers = sorted(neuron_number_score_data.keys())
|
||||
|
||||
for i, neuron_number in enumerate(neuron_numbers):
|
||||
for generation in range(max_generation + 1):
|
||||
scores = neuron_number_score_data[neuron_number][generation]
|
||||
if scores: # Check if there are scores for this generation
|
||||
neuron_number_score_heatmap_data[i, generation] = np.mean(scores)
|
||||
|
||||
# Creating a DataFrame for the major mutation rate heatmap
|
||||
df_neuron_number_score_heatmap = pd.DataFrame(
|
||||
data=neuron_number_score_heatmap_data,
|
||||
index=neuron_numbers,
|
||||
columns=range(max_generation + 1)
|
||||
)
|
||||
|
||||
# Analyze number of layers correlation to score
|
||||
nn_layers_score_data = defaultdict(lambda: defaultdict(list))
|
||||
|
||||
for node in right_nodes:
|
||||
node_val = node["val"]["node"]
|
||||
if node_val:
|
||||
scores = node_val["scores"]
|
||||
nn_shapes = node_val["nn_shapes"]
|
||||
# Both scores and nn_shapes are arrays where score is 1 less in length than nn_shapes (each index corresponds to a generation)
|
||||
for gen_index, score in enumerate(scores):
|
||||
for individual_id, nn_shape in nn_shapes[gen_index].items():
|
||||
layer_number = len(nn_shape)
|
||||
# check if score has a value for the individual id
|
||||
if individual_id not in score:
|
||||
continue
|
||||
nn_layers_score_data[layer_number][gen_index].append(score[individual_id])
|
||||
|
||||
# prepare the density map data
|
||||
nn_layers_score_heatmap_data = np.full((len(nn_layers_score_data), max_generation + 1), np.nan)
|
||||
nn_layers = sorted(nn_layers_score_data.keys())
|
||||
|
||||
for i, nn_layer in enumerate(nn_layers):
|
||||
for generation in range(max_generation + 1):
|
||||
scores = nn_layers_score_data[nn_layer][generation]
|
||||
if scores: # Check if there are scores for this generation
|
||||
nn_layers_score_heatmap_data[i, generation] = np.mean(scores)
|
||||
|
||||
# Creating a DataFrame for the major mutation rate heatmap
|
||||
df_nn_layers_score_heatmap = pd.DataFrame(
|
||||
data=nn_layers_score_heatmap_data,
|
||||
index=nn_layers,
|
||||
columns=range(max_generation + 1)
|
||||
)
|
||||
|
||||
# print("Format: ", custom_formatter(0.123498761234, 0))
|
||||
|
||||
# Creating subplots
|
||||
fig, axs = plt.subplots(2, 2, figsize=(20, 14)) # Creates a 3x2 grid of subplots
|
||||
|
||||
# Plotting the minor mutation rate heatmap
|
||||
sns.heatmap(df_heatmap.T, cmap='viridis', fmt=".4g", cbar_kws={'label': 'Mean Score'}, ax=axs[0, 0])
|
||||
# axs[0, 0].set_title('Minor Mutation Rate')
|
||||
axs[0, 0].set_xlabel('Minor Mutation Rate')
|
||||
axs[0, 0].set_ylabel('Generation')
|
||||
axs[0, 0].invert_yaxis()
|
||||
|
||||
# Plotting the major mutation rate heatmap
|
||||
sns.heatmap(df_major_heatmap.T, cmap='viridis', fmt=".4g", cbar_kws={'label': 'Mean Score'}, ax=axs[0, 1])
|
||||
# axs[0, 1].set_title('Major Mutation Rate')
|
||||
axs[0, 1].set_xlabel('Major Mutation Rate')
|
||||
axs[0, 1].invert_yaxis()
|
||||
|
||||
# Plotting the crossbreed_segments heatmap
|
||||
sns.heatmap(df_crossbreed_heatmap.T, cmap='viridis', fmt=".4g", cbar_kws={'label': 'Mean Score'}, ax=axs[1, 0])
|
||||
# axs[1, 0].set_title('Crossbreed Segments')
|
||||
axs[1, 0].set_xlabel('Crossbreed Segments')
|
||||
axs[1, 0].set_ylabel('Generation')
|
||||
axs[1, 0].invert_yaxis()
|
||||
|
||||
# Plotting the mutation_weight_range heatmap
|
||||
sns.heatmap(df_mutation_weight_range_heatmap.T, cmap='viridis', fmt=".4g", cbar_kws={'label': 'Mean Score'}, ax=axs[1, 1])
|
||||
# axs[1, 1].set_title('Mutation Weight Range')
|
||||
axs[1, 1].set_xlabel('Mutation Weight Range')
|
||||
axs[1, 1].invert_yaxis()
|
||||
|
||||
fig3, axs3 = plt.subplots(1, 2, figsize=(20, 14)) # Creates a 3x2 grid of subplots
|
||||
|
||||
# Plotting the weight_initialization_range heatmap
|
||||
sns.heatmap(df_weight_initialization_range_heatmap.T, cmap='viridis', fmt=".4g", cbar_kws={'label': 'Mean Score'}, ax=axs3[0])
|
||||
# axs[2, 0].set_title('Weight Initialization Range')
|
||||
axs3[0].set_xlabel('Weight Initialization Range')
|
||||
axs3[0].set_ylabel('Generation')
|
||||
axs3[0].invert_yaxis()
|
||||
|
||||
# Plotting the weight_initialization_range_skew heatmap
|
||||
sns.heatmap(df_weight_initialization_range_skew_heatmap.T, cmap='viridis', fmt=".4g", cbar_kws={'label': 'Mean Score'}, ax=axs3[1])
|
||||
# axs[2, 1].set_title('Weight Initialization Range Skew')
|
||||
axs3[1].set_xlabel('Weight Initialization Range Skew')
|
||||
axs3[1].set_ylabel('Generation')
|
||||
axs3[1].invert_yaxis()
|
||||
|
||||
# Creating a new window for the scatter plots
|
||||
fig2, axs2 = plt.subplots(2, 1, figsize=(20, 14)) # Creates a 2x1 grid of subplots
|
||||
|
||||
# Plotting the neuron number vs score heatmap
|
||||
sns.heatmap(df_neuron_number_score_heatmap.T, cmap='viridis', fmt=".4g", cbar_kws={'label': 'Mean Score'}, ax=axs2[1])
|
||||
# axs[3, 1].set_title('Neuron Number vs. Score')
|
||||
axs2[1].set_xlabel('Neuron Number')
|
||||
axs2[1].set_ylabel('Generation')
|
||||
axs2[1].invert_yaxis()
|
||||
|
||||
# Plotting the number of layers vs score heatmap
|
||||
sns.heatmap(df_nn_layers_score_heatmap.T, cmap='viridis', fmt=".4g", cbar_kws={'label': 'Mean Score'}, ax=axs2[0])
|
||||
# axs[3, 1].set_title('Number of Layers vs. Score')
|
||||
axs2[0].set_xlabel('Number of Layers')
|
||||
axs2[0].set_ylabel('Generation')
|
||||
axs2[0].invert_yaxis()
|
||||
|
||||
# Display the plot
|
||||
plt.tight_layout() # Adjusts the subplots to fit into the figure area.
|
||||
plt.show()
|
|
@ -1,12 +0,0 @@
|
|||
;; Define a type that contains a population size and a population cutoff
|
||||
(defclass simulation-node () ((population-size :initarg :population-size :accessor population-size)
|
||||
(population-cutoff :initarg :population-cutoff :accessor population-cutoff)
|
||||
(population :initform () :accessor population)))
|
||||
|
||||
;; Define a method that initializes population-size number of children in a population each with a random value
|
||||
(defmethod initialize-instance :after ((node simulation-node) &key)
|
||||
(setf (population node) (make-list (population-size node) :initial-element (random 100))))
|
||||
|
||||
(let ((node (make-instance 'simulation-node :population-size 100 :population-cutoff 10)))
|
||||
(print (population-size node))
|
||||
(population node))
|
118
visualize_networks.py
Normal file
118
visualize_networks.py
Normal file
|
@ -0,0 +1,118 @@
|
|||
import matplotlib.pyplot as plt
|
||||
import networkx as nx
|
||||
import subprocess
|
||||
import tkinter as tk
|
||||
from tkinter import filedialog
|
||||
|
||||
def select_file():
|
||||
root = tk.Tk()
|
||||
root.withdraw() # Hide the main window
|
||||
file_path = filedialog.askopenfilename(
|
||||
initialdir="/", # Set the initial directory to search for files
|
||||
title="Select file",
|
||||
filetypes=(("Net files", "*.net"), ("All files", "*.*"))
|
||||
)
|
||||
return file_path
|
||||
|
||||
def get_fann_data(network_file):
|
||||
# Adjust the path to the Rust executable as needed
|
||||
result = subprocess.run(['./extract_fann_data/target/debug/extract_fann_data.exe', network_file], capture_output=True, text=True)
|
||||
if result.returncode != 0:
|
||||
print("Error:", result.stderr)
|
||||
return None, None
|
||||
|
||||
layer_sizes = []
|
||||
connections = []
|
||||
parsing_connections = False
|
||||
|
||||
for line in result.stdout.splitlines():
|
||||
if line.startswith("Layers:"):
|
||||
continue
|
||||
elif line.startswith("Connections:"):
|
||||
parsing_connections = True
|
||||
continue
|
||||
|
||||
if parsing_connections:
|
||||
from_neuron, to_neuron, weight = map(float, line.split())
|
||||
connections.append((int(from_neuron), int(to_neuron), weight))
|
||||
else:
|
||||
layer_size, bias_count = map(int, line.split())
|
||||
layer_sizes.append((layer_size, bias_count))
|
||||
|
||||
return layer_sizes, connections
|
||||
|
||||
def visualize_fann_network(network_file):
|
||||
# Get network data
|
||||
layer_sizes, connections = get_fann_data(network_file)
|
||||
if layer_sizes is None or connections is None:
|
||||
return # Error handling in get_fann_data should provide error output
|
||||
|
||||
# Create a directed graph
|
||||
G = nx.DiGraph()
|
||||
|
||||
# Positions dictionary to hold the position of each neuron
|
||||
pos = {}
|
||||
node_count = 0
|
||||
x_spacing = 1.0
|
||||
y_spacing = 1.0
|
||||
|
||||
# Calculate the maximum layer size for proper spacing
|
||||
max_layer_size = max(size for size, bias in layer_sizes)
|
||||
|
||||
# Build nodes and position them layer by layer from left to right
|
||||
for layer_index, (layer_size, bias_count) in enumerate(layer_sizes):
|
||||
y_positions = list(range(-layer_size-bias_count+1, 1, 1)) # Center-align vertically
|
||||
y_positions = [y * (max_layer_size / (layer_size + bias_count)) * y_spacing for y in y_positions] # Adjust spacing
|
||||
for neuron_index in range(layer_size + bias_count): # Include bias neurons
|
||||
node_label = f"L{layer_index}N{neuron_index}"
|
||||
G.add_node(node_count, label=node_label)
|
||||
pos[node_count] = (layer_index * x_spacing, y_positions[neuron_index % len(y_positions)])
|
||||
node_count += 1
|
||||
|
||||
# Add connections to the graph
|
||||
for from_neuron, to_neuron, weight in connections:
|
||||
G.add_edge(from_neuron, to_neuron, weight=weight)
|
||||
|
||||
max_weight = max(abs(weight) for _, _, weight in connections)
|
||||
print(f"Max weight: {max_weight}")
|
||||
|
||||
# Draw nodes
|
||||
nx.draw_networkx_nodes(G, pos, node_color='skyblue', node_size=200)
|
||||
nx.draw_networkx_labels(G, pos, font_size=7)
|
||||
|
||||
# Custom function for edge properties
|
||||
def adjust_properties(weight):
|
||||
# if weight > 0:
|
||||
# print("Weight:", weight)
|
||||
color = 'green' if weight > 0 else 'red'
|
||||
alpha = min((abs(weight) / max_weight) ** 3, 1)
|
||||
# print(f"Color: {color}, Alpha: {alpha}")
|
||||
return color, alpha
|
||||
|
||||
# Draw edges with custom properties
|
||||
for u, v, d in G.edges(data=True):
|
||||
color, alpha = adjust_properties(d['weight'])
|
||||
nx.draw_networkx_edges(G, pos, edgelist=[(u, v)], edge_color=color, alpha=alpha, width=1.5, arrows=False)
|
||||
|
||||
# Show plot
|
||||
plt.title('FANN Network Visualization')
|
||||
plt.axis('off') # Turn off the axis
|
||||
plt.show()
|
||||
|
||||
# Path to the FANN network file
|
||||
fann_path = 'F:\\\\vandomej\\Projects\\dootcamp-AI-Simulation\\Simulations\\fighter_nn_4f2be613-ab26-4384-9a65-450e043984ea\\6\\4f2be613-ab26-4384-9a65-450e043984ea_fighter_nn_0.net'
|
||||
# fann_path = "F:\\\\vandomej\\Projects\\dootcamp-AI-Simulation\\Simulations\\fighter_nn_fc294503-7b2a-40f8-be59-ccc486eb3f79\\0\\fc294503-7b2a-40f8-be59-ccc486eb3f79_fighter_nn_0.net"
|
||||
# fann_path = 'F:\\\\vandomej\\Projects\\dootcamp-AI-Simulation\\Simulations\\fighter_nn_99c30a7f-40ab-4faf-b16a-b44703fdb6cd\\0\\99c30a7f-40ab-4faf-b16a-b44703fdb6cd_fighter_nn_0.net'
|
||||
# Has a 4 layer network
|
||||
# # Generation 1
|
||||
# fann_path = "F:\\\\vandomej\\Projects\\dootcamp-AI-Simulation\\Simulations\\fighter_nn_16dfa1b4-03c7-45a6-84b4-22fe3c8e2d98\\1\\16dfa1b4-03c7-45a6-84b4-22fe3c8e2d98_fighter_nn_0.net"
|
||||
# # Generation 5
|
||||
# fann_path = "F:\\\\vandomej\\Projects\\dootcamp-AI-Simulation\\Simulations\\fighter_nn_16dfa1b4-03c7-45a6-84b4-22fe3c8e2d98\\5\\16dfa1b4-03c7-45a6-84b4-22fe3c8e2d98_fighter_nn_0.net"
|
||||
# # Generation 10
|
||||
# fann_path = "F:\\\\vandomej\\Projects\\dootcamp-AI-Simulation\\Simulations\\fighter_nn_16dfa1b4-03c7-45a6-84b4-22fe3c8e2d98\\10\\16dfa1b4-03c7-45a6-84b4-22fe3c8e2d98_fighter_nn_0.net"
|
||||
# # Generation 20
|
||||
# fann_path = "F:\\\\vandomej\\Projects\\dootcamp-AI-Simulation\\Simulations\\fighter_nn_16dfa1b4-03c7-45a6-84b4-22fe3c8e2d98\\20\\16dfa1b4-03c7-45a6-84b4-22fe3c8e2d98_fighter_nn_0.net"
|
||||
# # Generation 32
|
||||
# fann_path = "F:\\\\vandomej\\Projects\\dootcamp-AI-Simulation\\Simulations\\fighter_nn_16dfa1b4-03c7-45a6-84b4-22fe3c8e2d98\\32\\16dfa1b4-03c7-45a6-84b4-22fe3c8e2d98_fighter_nn_0.net"
|
||||
fann_path = select_file()
|
||||
visualize_fann_network(fann_path)
|
104
visualize_simulation_tree.py
Normal file
104
visualize_simulation_tree.py
Normal file
|
@ -0,0 +1,104 @@
|
|||
# Re-importing necessary libraries
|
||||
import json
|
||||
import matplotlib.pyplot as plt
|
||||
import networkx as nx
|
||||
import random
|
||||
|
||||
def hierarchy_pos(G, root=None, width=1., vert_gap=0.2, vert_loc=0, xcenter=0.5):
|
||||
if not nx.is_tree(G):
|
||||
raise TypeError('cannot use hierarchy_pos on a graph that is not a tree')
|
||||
|
||||
if root is None:
|
||||
if isinstance(G, nx.DiGraph):
|
||||
root = next(iter(nx.topological_sort(G)))
|
||||
else:
|
||||
root = random.choice(list(G.nodes))
|
||||
|
||||
def _hierarchy_pos(G, root, width=2., vert_gap=0.2, vert_loc=0, xcenter=0.5, pos=None, parent=None):
|
||||
if pos is None:
|
||||
pos = {root: (xcenter, vert_loc)}
|
||||
else:
|
||||
pos[root] = (xcenter, vert_loc)
|
||||
children = list(G.successors(root)) # Use successors to get children for DiGraph
|
||||
if not isinstance(G, nx.DiGraph):
|
||||
if parent is not None:
|
||||
children.remove(parent)
|
||||
if len(children) != 0:
|
||||
dx = width / len(children)
|
||||
nextx = xcenter - width / 2 - dx / 2
|
||||
for child in children:
|
||||
nextx += dx
|
||||
pos = _hierarchy_pos(G, child, width=dx*2.0, vert_gap=vert_gap,
|
||||
vert_loc=vert_loc - vert_gap, xcenter=nextx,
|
||||
pos=pos, parent=root)
|
||||
return pos
|
||||
|
||||
return _hierarchy_pos(G, root, width, vert_gap, vert_loc, xcenter)
|
||||
|
||||
# Simplified JSON data for demonstration
|
||||
with open('gemla/round4.json', 'r') as file:
|
||||
simplified_json_data = json.load(file)
|
||||
|
||||
# Function to traverse the tree and create a graph
|
||||
def traverse(node, graph, parent=None):
|
||||
if node is None:
|
||||
return
|
||||
|
||||
node_id = node["val"]["id"]
|
||||
if "node" in node["val"] and node["val"]["node"]:
|
||||
scores = node["val"]["node"]["scores"]
|
||||
generations = node["val"]["node"]["generation"]
|
||||
population_size = node["val"]["node"]["population_size"]
|
||||
# Prepare to track the highest score across all generations and the corresponding individual
|
||||
overall_max_score = float('-inf')
|
||||
overall_max_score_individual = None
|
||||
overall_max_score_gen = None
|
||||
|
||||
for gen, gen_scores in enumerate(scores):
|
||||
if gen_scores: # Ensure the dictionary is not empty
|
||||
# Find the max score and the individual for this generation
|
||||
max_score_for_gen = max(gen_scores.values())
|
||||
individual_with_max_score_for_gen = max(gen_scores, key=gen_scores.get)
|
||||
|
||||
# if max_score_for_gen > overall_max_score:
|
||||
overall_max_score = max_score_for_gen
|
||||
overall_max_score_individual = individual_with_max_score_for_gen
|
||||
overall_max_score_gen = gen
|
||||
|
||||
# print debug statement
|
||||
# print(f"Node {node_id}: Max score: {overall_max_score:.6f} (Individual {overall_max_score_individual} in Gen {overall_max_score_gen})")
|
||||
# print(f"Left: {node.get('left')}, Right: {node.get('right')}")
|
||||
label = f"{node_id}\nGenerations: {generations}, Population: {population_size}\nMax score: {overall_max_score:.6f} (Individual {overall_max_score_individual} in Gen {overall_max_score_gen + 1 if overall_max_score_gen is not None else 'N/A'})"
|
||||
else:
|
||||
label = node_id
|
||||
|
||||
graph.add_node(node_id, label=label)
|
||||
if parent:
|
||||
graph.add_edge(parent, node_id)
|
||||
|
||||
traverse(node.get("left"), graph, parent=node_id)
|
||||
traverse(node.get("right"), graph, parent=node_id)
|
||||
|
||||
|
||||
# Create a directed graph
|
||||
G = nx.DiGraph()
|
||||
|
||||
# Populate the graph
|
||||
traverse(simplified_json_data[0], G)
|
||||
|
||||
# Find the root node (a node with no incoming edges)
|
||||
root_candidates = [node for node, indeg in G.in_degree() if indeg == 0]
|
||||
|
||||
if root_candidates:
|
||||
root_node = root_candidates[0] # Assuming there's only one root candidate
|
||||
else:
|
||||
root_node = None # This should ideally never happen in a properly structured tree
|
||||
|
||||
# Use the determined root node for hierarchy_pos
|
||||
if root_node is not None:
|
||||
pos = hierarchy_pos(G, root=root_node)
|
||||
labels = nx.get_node_attributes(G, 'label')
|
||||
nx.draw(G, pos, labels=labels, with_labels=True, arrows=True)
|
||||
plt.show()
|
||||
else:
|
||||
print("No root node found. Cannot draw the tree.")
|
Loading…
Add table
Reference in a new issue