Modifying executor runtime for efficiency
parent c9b746e59d
commit 97086fdbe0
7 changed files with 356 additions and 157 deletions
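For context on the executor change in main.rs below: a minimal sketch of the smol/easy-parallel thread-pool pattern being removed, using only the smol, easy-parallel, and num_cpus crates from the dependency hunk. The worker count and the value returned from the main future are illustrative, not taken from the commit.

use easy_parallel::Parallel;
use smol::{channel, future, Executor};

fn main() {
    // One async executor shared by all worker threads.
    let ex = Executor::new();
    // Dropping `signal` closes the channel, which tells every worker to stop.
    let (signal, shutdown) = channel::unbounded::<()>();

    let (_, result) = Parallel::new()
        // Each worker thread drives the executor until the shutdown channel closes.
        .each(0..num_cpus::get().max(1), |_| {
            future::block_on(ex.run(shutdown.recv()))
        })
        // The main future runs on the current thread alongside the workers.
        .finish(|| {
            future::block_on(async {
                // Application logic would run here before signalling shutdown.
                drop(signal);
                "done"
            })
        });

    assert_eq!(result, "done");
}

The removed lines in main.rs follow this shape, while the added lines replace it with a tokio::runtime::Builder::new_multi_thread() runtime; the later hunks then switch the GeneticNode trait and its tests over to async fn, #[async_trait], tokio::spawn, and #[tokio::test].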
analysis.py (new file, +101 lines)
@@ -0,0 +1,101 @@
# Re-importing necessary libraries
import json
import matplotlib.pyplot as plt
import networkx as nx
import random

def hierarchy_pos(G, root=None, width=1., vert_gap=0.2, vert_loc=0, xcenter=0.5):
    if not nx.is_tree(G):
        raise TypeError('cannot use hierarchy_pos on a graph that is not a tree')

    if root is None:
        if isinstance(G, nx.DiGraph):
            root = next(iter(nx.topological_sort(G)))
        else:
            root = random.choice(list(G.nodes))

    def _hierarchy_pos(G, root, width=1., vert_gap=0.2, vert_loc=0, xcenter=0.5, pos=None, parent=None):
        if pos is None:
            pos = {root: (xcenter, vert_loc)}
        else:
            pos[root] = (xcenter, vert_loc)
        children = list(G.successors(root))  # Use successors to get children for DiGraph
        if not isinstance(G, nx.DiGraph):
            if parent is not None:
                children.remove(parent)
        if len(children) != 0:
            dx = width / len(children)
            nextx = xcenter - width / 2 - dx / 2
            for child in children:
                nextx += dx
                pos = _hierarchy_pos(G, child, width=dx, vert_gap=vert_gap,
                                     vert_loc=vert_loc - vert_gap, xcenter=nextx,
                                     pos=pos, parent=root)
        return pos

    return _hierarchy_pos(G, root, width, vert_gap, vert_loc, xcenter)

# Simplified JSON data for demonstration
with open('gemla/test.json', 'r') as file:
    simplified_json_data = json.load(file)

# Function to traverse the tree and create a graph
def traverse(node, graph, parent=None):
    if node is None:
        return

    node_id = node["val"]["id"]
    if "node" in node["val"] and node["val"]["node"]:
        scores = node["val"]["node"]["scores"]
        generations = node["val"]["node"]["generation"]
        population_size = node["val"]["node"]["population_size"]
        # Prepare to track the highest score across all generations and the corresponding individual
        overall_max_score = float('-inf')
        overall_max_score_individual = None
        overall_max_score_gen = None

        for gen, gen_scores in enumerate(scores):
            if gen_scores:  # Ensure the dictionary is not empty
                # Find the max score and the individual for this generation
                max_score_for_gen = max(gen_scores.values())
                individual_with_max_score_for_gen = max(gen_scores, key=gen_scores.get)

                # Keep only the best score seen across all generations
                if max_score_for_gen > overall_max_score:
                    overall_max_score = max_score_for_gen
                    overall_max_score_individual = individual_with_max_score_for_gen
                    overall_max_score_gen = gen

        label = f"{node_id}\nGenerations: {generations}, Population: {population_size}\nMax score: {overall_max_score:.6f} (Individual {overall_max_score_individual} in Gen {overall_max_score_gen})"
    else:
        label = node_id

    graph.add_node(node_id, label=label)
    if parent:
        graph.add_edge(parent, node_id)

    traverse(node.get("left"), graph, parent=node_id)
    traverse(node.get("right"), graph, parent=node_id)


# Create a directed graph
G = nx.DiGraph()

# Populate the graph
traverse(simplified_json_data[0], G)

# Find the root node (a node with no incoming edges)
root_candidates = [node for node, indeg in G.in_degree() if indeg == 0]

if root_candidates:
    root_node = root_candidates[0]  # Assuming there's only one root candidate
else:
    root_node = None  # This should ideally never happen in a properly structured tree

# Use the determined root node for hierarchy_pos
if root_node is not None:
    pos = hierarchy_pos(G, root=root_node)
    labels = nx.get_node_attributes(G, 'label')
    nx.draw(G, pos, labels=labels, with_labels=True, arrows=True)
    plt.show()
else:
    print("No root node found. Cannot draw the tree.")
@@ -26,8 +26,8 @@ rand = "0.8.5"
log = "0.4.21"
env_logger = "0.11.3"
futures = "0.3.30"
smol = "2.0.0"
smol-potat = "1.1.2"
tokio = { version = "1.36.0", features = ["full"] }
num_cpus = "1.16.0"
easy-parallel = "3.3.1"
fann = "0.1.8"
async-trait = "0.1.78"
@@ -6,16 +6,17 @@ extern crate log;
mod test_state;
mod fighter_nn;

use easy_parallel::Parallel;
use file_linked::constants::data_format::DataFormat;
use gemla::{
    core::{Gemla, GemlaConfig},
    error::{log_error, Error},
    error::log_error,
};
use smol::{channel, channel::RecvError, future, Executor};
use std::{path::PathBuf, time::Instant};
use fighter_nn::FighterNN;
use clap::Parser;
use anyhow::Result;

// const NUM_THREADS: usize = 12;

#[derive(Parser)]
#[command(version, about, long_about = None)]
@@ -29,48 +30,40 @@ struct Args {
///
/// Use the -h, --h, or --help flag to see usage syntax.
/// TODO
fn main() -> anyhow::Result<()> {
fn main() -> Result<()> {
    env_logger::init();
    info!("Starting");

    let now = Instant::now();

    // Obtaining number of threads to use
    let num_threads = num_cpus::get().max(1);
    let ex = Executor::new();
    let (signal, shutdown) = channel::unbounded::<()>();
    // Manually configure the Tokio runtime
    let runtime: Result<()> = tokio::runtime::Builder::new_multi_thread()
        .worker_threads(num_cpus::get())
        // .worker_threads(NUM_THREADS)
        .build()?
        .block_on(async {
            let args = Args::parse(); // Assuming Args::parse() doesn't need to be async
            let mut gemla = log_error(Gemla::<FighterNN>::new(
                &PathBuf::from(args.file),
                GemlaConfig {
                    generations_per_height: 10,
                    overwrite: false,
                },
                DataFormat::Json,
            ))?;

    // Create an executor thread pool.
    let (_, result): (Vec<Result<(), RecvError>>, Result<(), Error>) = Parallel::new()
        .each(0..num_threads, |_| {
            future::block_on(ex.run(shutdown.recv()))
        })
        .finish(|| {
            smol::block_on(async {
                drop(signal);
                // let gemla_arc = Arc::new(gemla);

                // Command line arguments are parsed with the clap crate.
                let args = Args::parse();
            // Setup your application logic here
            // If `gemla::simulate` needs to run sequentially, simply call it in sequence without spawning new tasks

                // Checking that the first argument <FILE> is a valid file
                let mut gemla = log_error(Gemla::<FighterNN>::new(
                    &PathBuf::from(args.file),
                    GemlaConfig {
                        generations_per_height: 3,
                        overwrite: false,
                    },
                    DataFormat::Json,
                ))?;

                loop {
                    log_error(gemla.simulate(5).await)?;
                }
            })
            // Example placeholder loop to continuously run simulate
            loop { // Arbitrary loop count for demonstration
                gemla.simulate(5).await?;
            }
        });

    result?;
    runtime?; // Handle errors from the block_on call

    info!("Finished in {:?}", now.elapsed());

    Ok(())
}
}
@@ -1,7 +1,8 @@
extern crate fann;

use std::{fs, path::PathBuf};
use std::{fs::{self, File}, io::{self, BufRead, BufReader}, path::{Path, PathBuf}};
use fann::{ActivationFunc, Fann};
use futures::future::join_all;
use gemla::{core::genetic_node::{GeneticNode, GeneticNodeContext}, error::Error};
use rand::prelude::*;
use rand::distributions::{Distribution, Uniform};
@@ -9,12 +10,15 @@ use serde::{Deserialize, Serialize};
use anyhow::Context;
use uuid::Uuid;
use std::collections::HashMap;
use tokio::process::Command;
use async_trait::async_trait;

const BASE_DIR: &str = "F:\\\\vandomej\\Projects\\dootcamp-AI-Simulation\\Simulations";
const POPULATION: usize = 100;
const POPULATION: usize = 50;
const NEURAL_NETWORK_SHAPE: &[u32; 5] = &[14, 20, 20, 12, 8];
const SIMULATION_ROUNDS: usize = 10;
const SIMULATION_ROUNDS: usize = 5;
const SURVIVAL_RATE: f32 = 0.5;
const GAME_EXECUTABLE_PATH: &str = "F:\\\\vandomej\\Projects\\dootcamp-AI-Simulation\\Package\\Windows\\AI_Fight_Sim.exe";

// Here is the folder structure for the FighterNN:
// base_dir/fighter_nn_{fighter_id}/{generation}/{fighter_id}_fighter_nn_{nn_id}.net
@@ -36,9 +40,10 @@ pub struct FighterNN {
    pub scores: Vec<HashMap<u64, f32>>,
}

#[async_trait]
impl GeneticNode for FighterNN {
    // Check for the highest number of the folder name and increment it by 1
    fn initialize(context: &GeneticNodeContext) -> Result<Box<Self>, Error> {
    fn initialize(context: GeneticNodeContext) -> Result<Box<Self>, Error> {
        let base_path = PathBuf::from(BASE_DIR);

        let folder = base_path.join(format!("fighter_nn_{:06}", context.id));
@@ -57,6 +62,7 @@ impl GeneticNode for FighterNN {
            let nn = gen_folder.join(format!("{:06}_fighter_nn_{}.net", context.id, i));
            let mut fann = Fann::new(NEURAL_NETWORK_SHAPE)
                .with_context(|| "Failed to create nn")?;
            fann.randomize_weights(-0.8, 0.8);
            fann.set_activation_func_hidden(ActivationFunc::SigmoidSymmetric);
            fann.set_activation_func_output(ActivationFunc::SigmoidSymmetric);
            // This will overwrite any existing file with the same name
@@ -73,47 +79,89 @@ impl GeneticNode for FighterNN {
        }))
    }

    fn simulate(&mut self, _context: &GeneticNodeContext) -> Result<(), Error> {
    async fn simulate(&mut self, _context: GeneticNodeContext) -> Result<(), Error> {
        // For each nn in the current generation:
        for i in 0..self.population_size {
            // load the nn
            let nn = self.folder.join(format!("{}", self.generation)).join(format!("{:06}_fighter_nn_{}.net", self.id, i));
            let fann = Fann::from_file(&nn)
                .with_context(|| format!("Failed to load nn"))?;

            // Simulate the nn against the random nn
            let mut score = 0.0;

            // Using the same original nn, repeat the simulation with 5 random nn's from the current generation
            let mut simulations = Vec::new();

            // Using the same original nn, repeat the simulation with 5 random nn's from the current generation concurrently
            for _ in 0..SIMULATION_ROUNDS {
                let random_nn = self.folder.join(format!("{}", self.generation)).join(format!("{:06}_fighter_nn_{}.net", self.id, thread_rng().gen_range(0..self.population_size)));
                let random_fann = Fann::from_file(&random_nn)
                    .with_context(|| format!("Failed to load random nn"))?;
                let random_nn_index = thread_rng().gen_range(0..self.population_size);
                let id = self.id.clone();
                let folder = self.folder.clone();
                let generation = self.generation;

                let inputs: Vec<f32> = (0..NEURAL_NETWORK_SHAPE[0]).map(|_| thread_rng().gen_range(-1.0..1.0)).collect();
                let outputs = fann.run(&inputs)
                    .with_context(|| format!("Failed to run nn"))?;
                let random_outputs = random_fann.run(&inputs)
                    .with_context(|| format!("Failed to run random nn"))?;

                // Average the difference between the outputs of the nn and random_nn and add the result to score
                let mut round_score = 0.0;
                for (o, r) in outputs.iter().zip(random_outputs.iter()) {
                    round_score += o - r;
                }
                score += round_score / fann.get_num_output() as f32;
                let random_nn = folder.join(format!("{}", generation)).join(format!("{:06}_fighter_nn_{}.net", id, random_nn_index));
                let nn_clone = nn.clone(); // Clone the path to use in the async block

                let config1_arg = format!("-NN1Config=\"{}\"", nn_clone.to_str().unwrap());
                let config2_arg = format!("-NN2Config=\"{}\"", random_nn.to_str().unwrap());
                let disable_unreal_rendering_arg = "-nullrhi".to_string();

                let future = async move {
                    // Construct the score file path
                    let nn_id = format!("{:06}_fighter_nn_{}", id, i);
                    let random_nn_id = format!("{:06}_fighter_nn_{}", id, random_nn_index);
                    let score_file_name = format!("{}_vs_{}.txt", nn_id, random_nn_id);
                    let score_file = folder.join(format!("{}", generation)).join(&score_file_name);

                    // Check if score file already exists before running the simulation
                    if score_file.exists() {
                        let round_score = read_score_from_file(&score_file, &nn_id)
                            .with_context(|| format!("Failed to read score from file: {:?}", score_file_name))?;
                        return Ok::<f32, Error>(round_score);
                    }

                    // Check if the opposite round score has been determined
                    let opposite_score_file = folder.join(format!("{}", generation)).join(format!("{}_vs_{}.txt", random_nn_id, nn_id));
                    if opposite_score_file.exists() {
                        let round_score = read_score_from_file(&opposite_score_file, &nn_id)
                            .with_context(|| format!("Failed to read score from file: {:?}", opposite_score_file))?;
                        return Ok::<f32, Error>(1.0 - round_score);
                    }

                    if thread_rng().gen_range(0..100) < 4 {
                        let _output = Command::new(GAME_EXECUTABLE_PATH)
                            .arg(&config1_arg)
                            .arg(&config2_arg)
                            .output()
                            .await
                            .expect("Failed to execute game");
                    } else {
                        let _output = Command::new(GAME_EXECUTABLE_PATH)
                            .arg(&config1_arg)
                            .arg(&config2_arg)
                            .arg(&disable_unreal_rendering_arg)
                            .output()
                            .await
                            .expect("Failed to execute game");
                    }

                    // Read the score from the file
                    let round_score = read_score_from_file(&score_file, &nn_id)
                        .with_context(|| format!("Failed to read score from file: {:?}", score_file_name))?;

                    Ok::<f32, Error>(round_score)
                };

                simulations.push(future);
            }

            score /= 5.0;

            // Wait for all simulation rounds to complete
            let results: Result<Vec<f32>, Error> = join_all(simulations).await.into_iter().collect();

            let score = results?.into_iter().sum::<f32>() / SIMULATION_ROUNDS as f32;
            trace!("NN {:06}_fighter_nn_{} scored {}", self.id, i, score);
            self.scores[self.generation as usize].insert(i as u64, score);
        }


        Ok(())
    }


    fn mutate(&mut self, _context: &GeneticNodeContext) -> Result<(), Error> {
    fn mutate(&mut self, _context: GeneticNodeContext) -> Result<(), Error> {
        let survivor_count = (self.population_size as f32 * SURVIVAL_RATE) as usize;

        // Create the new generation folder
@@ -234,3 +282,23 @@ impl GeneticNode for FighterNN {
        }))
    }
}

fn read_score_from_file(file_path: &Path, nn_id: &str) -> Result<f32, io::Error> {
    let file = File::open(file_path)?;
    let reader = BufReader::new(file);

    for line in reader.lines() {
        let line = line?;
        if line.starts_with(nn_id) {
            let parts: Vec<&str> = line.split(':').collect();
            if parts.len() == 2 {
                return parts[1].trim().parse::<f32>().map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e));
            }
        }
    }

    Err(io::Error::new(
        io::ErrorKind::NotFound,
        "NN ID not found in scores file",
    ))
}
@@ -2,6 +2,7 @@ use gemla::{core::genetic_node::{GeneticNode, GeneticNodeContext}, error::Error}
use rand::prelude::*;
use serde::{Deserialize, Serialize};
use uuid::Uuid;
use async_trait::async_trait;

const POPULATION_SIZE: u64 = 5;
const POPULATION_REDUCTION_SIZE: u64 = 3;
@@ -11,8 +12,9 @@ pub struct TestState {
    pub population: Vec<i64>,
}

#[async_trait]
impl GeneticNode for TestState {
    fn initialize(_context: &GeneticNodeContext) -> Result<Box<Self>, Error> {
    fn initialize(_context: GeneticNodeContext) -> Result<Box<Self>, Error> {
        let mut population: Vec<i64> = vec![];

        for _ in 0..POPULATION_SIZE {
@@ -22,7 +24,7 @@ impl GeneticNode for TestState {
        Ok(Box::new(TestState { population }))
    }

    fn simulate(&mut self, _context: &GeneticNodeContext) -> Result<(), Error> {
    async fn simulate(&mut self, _context: GeneticNodeContext) -> Result<(), Error> {
        let mut rng = thread_rng();

        self.population = self
@@ -34,7 +36,7 @@ impl GeneticNode for TestState {
        Ok(())
    }

    fn mutate(&mut self, _context: &GeneticNodeContext) -> Result<(), Error> {
    fn mutate(&mut self, _context: GeneticNodeContext) -> Result<(), Error> {
        let mut rng = thread_rng();

        let mut v = self.population.clone();
@@ -83,7 +85,7 @@ impl GeneticNode for TestState {

        let mut result = TestState { population: v };

        result.mutate(&GeneticNodeContext {
        result.mutate(GeneticNodeContext {
            id: id.clone(),
            generation: 0,
            max_generations: 0,
@@ -101,7 +103,7 @@ mod tests {
    #[test]
    fn test_initialize() {
        let state = TestState::initialize(
            &GeneticNodeContext {
            GeneticNodeContext {
                id: Uuid::new_v4(),
                generation: 0,
                max_generations: 0,
@@ -111,8 +113,8 @@ mod tests {
        assert_eq!(state.population.len(), POPULATION_SIZE as usize);
    }

    #[test]
    fn test_simulate() {
    #[tokio::test]
    async fn test_simulate() {
        let mut state = TestState {
            population: vec![1, 1, 2, 3],
        };
@@ -120,31 +122,31 @@ mod tests {
        let original_population = state.population.clone();

        state.simulate(
            &GeneticNodeContext {
            GeneticNodeContext {
                id: Uuid::new_v4(),
                generation: 0,
                max_generations: 0,
            }
        ).unwrap();
        ).await.unwrap();
        assert!(original_population
            .iter()
            .zip(state.population.iter())
            .all(|(&a, &b)| b >= a - 1 && b <= a + 2));

        state.simulate(
            &GeneticNodeContext {
            GeneticNodeContext {
                id: Uuid::new_v4(),
                generation: 0,
                max_generations: 0,
            }
        ).unwrap();
        ).await.unwrap();
        state.simulate(
            &GeneticNodeContext {
            GeneticNodeContext {
                id: Uuid::new_v4(),
                generation: 0,
                max_generations: 0,
            }
        ).unwrap();
        ).await.unwrap();
        assert!(original_population
            .iter()
            .zip(state.population.iter())
@@ -158,7 +160,7 @@ mod tests {
        };

        state.mutate(
            &GeneticNodeContext {
            GeneticNodeContext {
                id: Uuid::new_v4(),
                generation: 0,
                max_generations: 0,
@@ -8,6 +8,7 @@ use anyhow::Context;
use serde::{Deserialize, Serialize};
use std::fmt::Debug;
use uuid::Uuid;
use async_trait::async_trait;

/// An enum used to control the state of a [`GeneticNode`]
///
@@ -24,6 +25,7 @@ pub enum GeneticState {
    Finish,
}

#[derive(Clone)]
pub struct GeneticNodeContext {
    pub generation: u64,
    pub max_generations: u64,
@@ -33,20 +35,21 @@ pub struct GeneticNodeContext {
/// A trait used to interact with the internal state of nodes within the [`Bracket`]
///
/// [`Bracket`]: crate::bracket::Bracket
pub trait GeneticNode {
#[async_trait]
pub trait GeneticNode: Send {
    /// Initializes a new instance of a [`GeneticState`].
    ///
    /// # Examples
    /// TODO
    fn initialize(context: &GeneticNodeContext) -> Result<Box<Self>, Error>;
    fn initialize(context: GeneticNodeContext) -> Result<Box<Self>, Error>;

    fn simulate(&mut self, context: &GeneticNodeContext) -> Result<(), Error>;
    async fn simulate(&mut self, context: GeneticNodeContext) -> Result<(), Error>;

    /// Mutates members in a population and/or crossbreeds them to produce new offspring.
    ///
    /// # Examples
    /// TODO
    fn mutate(&mut self, context: &GeneticNodeContext) -> Result<(), Error>;
    fn mutate(&mut self, context: GeneticNodeContext) -> Result<(), Error>;

    fn merge(left: &Self, right: &Self, id: &Uuid) -> Result<Box<Self>, Error>;
}
@@ -76,7 +79,7 @@ impl<T> Default for GeneticNodeWrapper<T> {

impl<T> GeneticNodeWrapper<T>
where
    T: GeneticNode + Debug,
    T: GeneticNode + Debug + Send,
{
    pub fn new(max_generations: u64) -> Self {
        GeneticNodeWrapper::<T> {
@@ -115,7 +118,7 @@ where
        self.state
    }

    pub fn process_node(&mut self) -> Result<GeneticState, Error> {
    pub async fn process_node(&mut self) -> Result<GeneticState, Error> {
        let context = GeneticNodeContext {
            generation: self.generation,
            max_generations: self.max_generations,
@@ -124,11 +127,11 @@ where

        match (self.state, &mut self.node) {
            (GeneticState::Initialize, _) => {
                self.node = Some(*T::initialize(&context)?);
                self.node = Some(*T::initialize(context.clone())?);
                self.state = GeneticState::Simulate;
            }
            (GeneticState::Simulate, Some(n)) => {
                n.simulate(&context)
                n.simulate(context.clone()).await
                    .with_context(|| format!("Error simulating node: {:?}", self))?;

                self.state = if self.generation >= self.max_generations {
@@ -138,7 +141,7 @@ where
                };
            }
            (GeneticState::Mutate, Some(n)) => {
                n.mutate(&context)
                n.mutate(context.clone())
                    .with_context(|| format!("Error mutating node: {:?}", self))?;

                self.generation += 1;
@@ -157,23 +160,25 @@ mod tests {
    use super::*;
    use crate::error::Error;
    use anyhow::anyhow;
    use async_trait::async_trait;

    #[derive(Deserialize, Serialize, Clone, Debug, PartialEq)]
    struct TestState {
        pub score: f64,
    }

    #[async_trait]
    impl GeneticNode for TestState {
        fn simulate(&mut self, _context: &GeneticNodeContext) -> Result<(), Error> {
        async fn simulate(&mut self, _context: GeneticNodeContext) -> Result<(), Error> {
            self.score += 1.0;
            Ok(())
        }

        fn mutate(&mut self, _context: &GeneticNodeContext) -> Result<(), Error> {
        fn mutate(&mut self, _context: GeneticNodeContext) -> Result<(), Error> {
            Ok(())
        }

        fn initialize(_context: &GeneticNodeContext) -> Result<Box<TestState>, Error> {
        fn initialize(_context: GeneticNodeContext) -> Result<Box<TestState>, Error> {
            Ok(Box::new(TestState { score: 0.0 }))
        }

@@ -270,16 +275,16 @@ mod tests {
        Ok(())
    }

    #[test]
    fn test_process_node() -> Result<(), Error> {
    #[tokio::test]
    async fn test_process_node() -> Result<(), Error> {
        let mut genetic_node = GeneticNodeWrapper::<TestState>::new(2);

        assert_eq!(genetic_node.state(), GeneticState::Initialize);
        assert_eq!(genetic_node.process_node()?, GeneticState::Simulate);
        assert_eq!(genetic_node.process_node()?, GeneticState::Mutate);
        assert_eq!(genetic_node.process_node()?, GeneticState::Simulate);
        assert_eq!(genetic_node.process_node()?, GeneticState::Finish);
        assert_eq!(genetic_node.process_node()?, GeneticState::Finish);
        assert_eq!(genetic_node.process_node().await?, GeneticState::Simulate);
        assert_eq!(genetic_node.process_node().await?, GeneticState::Mutate);
        assert_eq!(genetic_node.process_node().await?, GeneticState::Simulate);
        assert_eq!(genetic_node.process_node().await?, GeneticState::Finish);
        assert_eq!(genetic_node.process_node().await?, GeneticState::Finish);

        Ok(())
    }
@@ -5,10 +5,11 @@ pub mod genetic_node;

use crate::{error::Error, tree::Tree};
use file_linked::{constants::data_format::DataFormat, FileLinked};
use futures::{future, future::BoxFuture};
use futures::future;
use genetic_node::{GeneticNode, GeneticNodeWrapper, GeneticState};
use log::{info, trace, warn};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use tokio::task::JoinHandle;
use std::{
    collections::HashMap, fmt::Debug, fs::File, io::ErrorKind, marker::Send, mem, path::Path,
    time::Instant,
@@ -65,15 +66,15 @@ pub struct GemlaConfig {
/// individuals.
///
/// [`GeneticNode`]: genetic_node::GeneticNode
pub struct Gemla<'a, T>
pub struct Gemla<T>
where
    T: Serialize + Clone,
{
    pub data: FileLinked<(Option<SimulationTree<T>>, GemlaConfig)>,
    threads: HashMap<Uuid, BoxFuture<'a, Result<GeneticNodeWrapper<T>, Error>>>,
    threads: HashMap<Uuid, JoinHandle<Result<GeneticNodeWrapper<T>, Error>>>,
}

impl<'a, T: 'a> Gemla<'a, T>
impl<T: 'static> Gemla<T>
where
    T: GeneticNode + Serialize + DeserializeOwned + Debug + Clone + Send,
{
@@ -147,7 +148,9 @@ where
            trace!("Adding node to process list {}", node.id());

            self.threads
                .insert(node.id(), Box::pin(Gemla::process_node(node)));
                .insert(node.id(), tokio::spawn(async move {
                    Gemla::process_node(node).await
                }));
        } else {
            trace!("No node found to process, joining threads");

@@ -163,9 +166,10 @@ where
            trace!("Joining threads for nodes {:?}", self.threads.keys());

            let results = future::join_all(self.threads.values_mut()).await;

            // Converting a list of results into a result wrapping the list
            let reduced_results: Result<Vec<GeneticNodeWrapper<T>>, Error> =
                results.into_iter().collect();
                results.into_iter().flatten().collect();
            self.threads.clear();

            // We need to retrieve the processed nodes from the resulting list and replace them in the original list
@@ -323,7 +327,12 @@ where
        let node_state_time = Instant::now();
        let node_state = node.state();

        node.process_node()?;
        node.process_node().await?;

        if node.state() == GeneticState::Simulate
        {
            node.process_node().await?;
        }

        trace!(
            "{:?} completed in {:?} for {}",
@@ -346,6 +355,8 @@ mod tests {
    use serde::{Deserialize, Serialize};
    use std::path::PathBuf;
    use std::fs;
    use async_trait::async_trait;
    use tokio::runtime::Runtime;

    use self::genetic_node::GeneticNodeContext;

@@ -378,17 +389,18 @@ mod tests {
        pub score: f64,
    }

    #[async_trait]
    impl genetic_node::GeneticNode for TestState {
        fn simulate(&mut self, _context: &GeneticNodeContext) -> Result<(), Error> {
        async fn simulate(&mut self, _context: GeneticNodeContext) -> Result<(), Error> {
            self.score += 1.0;
            Ok(())
        }

        fn mutate(&mut self, _context: &GeneticNodeContext) -> Result<(), Error> {
        fn mutate(&mut self, _context: GeneticNodeContext) -> Result<(), Error> {
            Ok(())
        }

        fn initialize(_context: &GeneticNodeContext) -> Result<Box<TestState>, Error> {
        fn initialize(_context: GeneticNodeContext) -> Result<Box<TestState>, Error> {
            Ok(Box::new(TestState { score: 0.0 }))
        }

@@ -401,66 +413,84 @@ mod tests {
        }
    }

    #[test]
    fn test_new() -> Result<(), Error> {
    #[tokio::test]
    async fn test_new() -> Result<(), Error> {
        let path = PathBuf::from("test_new_non_existing");
        CleanUp::new(&path).run(|p| {
            assert!(!path.exists());
        // Use `spawn_blocking` to run synchronous code that needs to call async code internally.
        tokio::task::spawn_blocking(move || {
            let rt = Runtime::new().unwrap(); // Create a new Tokio runtime for the async block.
            CleanUp::new(&path).run(move |p| {
                rt.block_on(async {
                    assert!(!path.exists());

            // Testing initial creation
            let mut config = GemlaConfig {
                generations_per_height: 1,
                overwrite: true
            };
            let mut gemla = Gemla::<TestState>::new(&p, config, DataFormat::Json)?;
                    // Testing initial creation
                    let mut config = GemlaConfig {
                        generations_per_height: 1,
                        overwrite: true,
                    };
                    let mut gemla = Gemla::<TestState>::new(&p, config, DataFormat::Json)?;

            smol::block_on(gemla.simulate(2))?;
            assert_eq!(gemla.data.readonly().0.as_ref().unwrap().height(), 2);

            drop(gemla);
            assert!(path.exists());
                    // Now we can use `.await` within the spawned blocking task.
                    gemla.simulate(2).await?;
                    assert_eq!(gemla.data.readonly().0.as_ref().unwrap().height(), 2);

            // Testing overwriting data
            let mut gemla = Gemla::<TestState>::new(&p, config, DataFormat::Json)?;
                    drop(gemla);
                    assert!(path.exists());

            smol::block_on(gemla.simulate(2))?;
            assert_eq!(gemla.data.readonly().0.as_ref().unwrap().height(), 2);
                    // Testing overwriting data
                    let mut gemla = Gemla::<TestState>::new(&p, config, DataFormat::Json)?;

            drop(gemla);
            assert!(path.exists());
                    gemla.simulate(2).await?;
                    assert_eq!(gemla.data.readonly().0.as_ref().unwrap().height(), 2);

            // Testing not-overwriting data
            config.overwrite = false;
            let mut gemla = Gemla::<TestState>::new(&p, config, DataFormat::Json)?;
                    drop(gemla);
                    assert!(path.exists());

            smol::block_on(gemla.simulate(2))?;
            assert_eq!(gemla.tree_ref().unwrap().height(), 4);
                    // Testing not-overwriting data
                    config.overwrite = false;
                    let mut gemla = Gemla::<TestState>::new(&p, config, DataFormat::Json)?;

            drop(gemla);
            assert!(path.exists());
                    gemla.simulate(2).await?;
                    assert_eq!(gemla.tree_ref().unwrap().height(), 4);

            Ok(())
        })
                    drop(gemla);
                    assert!(path.exists());

                    Ok(())
                })
            })
        }).await.unwrap()?; // Wait for the blocking task to complete, then handle the Result.

        Ok(())
    }

    #[test]
    fn test_simulate() -> Result<(), Error> {
    #[tokio::test]
    async fn test_simulate() -> Result<(), Error> {
        let path = PathBuf::from("test_simulate");
        CleanUp::new(&path).run(|p| {
            // Testing initial creation
            let config = GemlaConfig {
                generations_per_height: 10,
                overwrite: true
            };
            let mut gemla = Gemla::<TestState>::new(&p, config, DataFormat::Json)?;
        // Use `spawn_blocking` to run the synchronous closure that internally awaits async code.
        tokio::task::spawn_blocking(move || {
            let rt = Runtime::new().unwrap(); // Create a new Tokio runtime for the async block.
            CleanUp::new(&path).run(move |p| {
                rt.block_on(async {
                    // Testing initial creation
                    let config = GemlaConfig {
                        generations_per_height: 10,
                        overwrite: true,
                    };
                    let mut gemla = Gemla::<TestState>::new(&p, config, DataFormat::Json)?;

            smol::block_on(gemla.simulate(5))?;
            let tree = gemla.tree_ref().unwrap();
            assert_eq!(tree.height(), 5);
            assert_eq!(tree.val.as_ref().unwrap().score, 50.0);
                    // Now we can use `.await` within the spawned blocking task.
                    gemla.simulate(5).await?;
                    let tree = gemla.tree_ref().unwrap();
                    assert_eq!(tree.height(), 5);
                    assert_eq!(tree.val.as_ref().unwrap().score, 50.0);

            Ok(())
        })
                    Ok(())
                })
            })
        }).await.unwrap()?; // Wait for the blocking task to complete, then handle the Result.

        Ok(())
    }

}