GEMLA/gemla/src/core/mod.rs
2025-09-08 16:40:57 -07:00

637 lines
24 KiB
Rust

//! Simulates a genetic algorithm on a population in order to improve the fit score and performance. The simulations
//! are performed in a tournament bracket configuration so that populations can compete against each other.
pub mod genetic_node;
use crate::{error::Error, tree::Tree};
use async_recursion::async_recursion;
use file_linked::{constants::data_format::DataFormat, FileLinked};
use futures::future;
use genetic_node::{GeneticNode, GeneticNodeWrapper, GeneticState};
use log::{info, trace, warn};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use std::{
collections::HashMap, fmt::Debug, fs::File, io::ErrorKind, marker::Send, mem, path::Path,
sync::Arc, time::Instant,
};
use tokio::{sync::RwLock, task::JoinHandle};
use uuid::Uuid;
/// A boxed binary [`Tree`] whose node values are [`GeneticNodeWrapper`]s; this is the shape in which
/// a [`Gemla`] stores its simulation bracket.
type SimulationTree<T> = Box<Tree<GeneticNodeWrapper<T>>>;
/// Provides configuration options for managing a [`Gemla`] object as it executes.
///
/// # Examples
/// ```rust,ignore
/// // `TestState` stands in for any type implementing `genetic_node::GeneticNode`,
/// // and `path` for the location of the file the simulation state is linked to.
/// let config = GemlaConfig { overwrite: true };
/// let mut gemla = Gemla::<TestState>::new(path, config, DataFormat::Json).await?;
/// gemla.simulate(2).await?;
/// ```
#[derive(Serialize, Deserialize, Copy, Clone)]
pub struct GemlaConfig {
/// Controls what [`Gemla::new`] does when the target file already exists: `true` replaces its
/// contents with a fresh, empty state, `false` loads the state stored in the file.
pub overwrite: bool,
}
/// Creates a tournament style bracket for simulating and evaluating nodes of type `T` implementing [`GeneticNode`].
/// The bracket is stored as a binary tree that is built upwards starting from the bottom: every time the
/// tree is grown, the existing tree becomes the left branch of a new parent node. Nodes are processed from
/// the leaves towards the root, and finished children are merged into their parent so that the strongest
/// individuals move up the bracket.
///
/// [`GeneticNode`]: genetic_node::GeneticNode
pub struct Gemla<T>
where
T: GeneticNode + Serialize + DeserializeOwned + Debug + Send + Clone,
T::Context: Send + Sync + Clone + Debug + Serialize + DeserializeOwned + 'static + Default,
{
/// File-linked state: the simulation tree (absent until first grown), the configuration, and the
/// context value handed to every node while it is processed.
pub data: FileLinked<(Option<SimulationTree<T>>, GemlaConfig, T::Context)>,
/// Handles of the node-processing tasks that are currently in flight, keyed by node id.
threads: HashMap<Uuid, JoinHandle<Result<GeneticNodeWrapper<T>, Error>>>,
}
impl<T: 'static> Gemla<T>
where
T: GeneticNode + Serialize + DeserializeOwned + Debug + Send + Sync + Clone,
T::Context: Send + Sync + Clone + Debug + Serialize + DeserializeOwned + 'static + Default,
{
pub async fn new(
path: &Path,
config: GemlaConfig,
data_format: DataFormat,
) -> Result<Self, Error> {
match File::open(path) {
// If the file exists we either want to overwrite the file or read from the file
// based on the configuration provided
Ok(_) => Ok(Gemla {
data: if config.overwrite {
FileLinked::new((None, config, T::Context::default()), path, data_format)
.await?
} else {
FileLinked::from_file(path, data_format)?
},
threads: HashMap::new(),
}),
// If the file doesn't exist we must create it
Err(error) if error.kind() == ErrorKind::NotFound => Ok(Gemla {
data: FileLinked::new((None, config, T::Context::default()), path, data_format)
.await?,
threads: HashMap::new(),
}),
Err(error) => Err(Error::IO(error)),
}
}
/// Returns a shared, lock-protected handle to the state tuple (tree, configuration, context) so
/// that callers can inspect it.
pub fn tree_ref(&self) -> Arc<RwLock<(Option<SimulationTree<T>>, GemlaConfig, T::Context)>> {
    let shared_state = self.data.readonly();
    shared_state.clone()
}
pub async fn simulate(&mut self, steps: u64) -> Result<(), Error> {
let tree_completed = {
// Only increase height if the tree is uninitialized or completed
let data_arc = self.data.readonly();
let data_ref = data_arc.read().await;
let tree_ref = data_ref.0.as_ref();
tree_ref.is_none() || tree_ref.map(|t| Gemla::is_completed(t)).unwrap_or(true)
};
if tree_completed {
// Before we can process nodes we must create blank nodes in their place to keep track of which nodes have been processed
// in the tree and which nodes have not.
self.data
.mutate(|(d, _, _)| {
let mut tree: Option<SimulationTree<T>> =
Gemla::increase_height(d.take(), steps);
mem::swap(d, &mut tree);
})
.await?;
}
{
// Only increase height if the tree is uninitialized or completed
let data_arc = self.data.readonly();
let data_ref = data_arc.read().await;
let tree_ref = data_ref.0.as_ref();
info!(
"Height of simulation tree increased to {}",
tree_ref
.map(|t| format!("{}", t.height()))
.unwrap_or_else(|| "Tree is not defined".to_string())
);
}
loop {
let is_tree_processed;
{
let data_arc = self.data.readonly();
let data_ref = data_arc.read().await;
let tree_ref = data_ref.0.as_ref();
is_tree_processed = tree_ref.map(|t| Gemla::is_completed(t)).unwrap_or(false)
}
// We need to keep simulating until the tree has been completely processed.
if is_tree_processed {
self.join_threads().await?;
info!("Processed tree");
break;
}
let (node, gemla_context) = {
let data_arc = self.data.readonly();
let data_ref = data_arc.read().await;
let (tree_ref, _, gemla_context) = &*data_ref; // (Option<Box<Tree<GeneticNodeWrapper<T>>>, GemlaConfig, T::Context)
let node = tree_ref.as_ref().and_then(|t| self.get_unprocessed_node(t));
(node, gemla_context.clone())
};
if let Some(node) = node {
trace!("Adding node to process list {}", node.id());
let gemla_context = gemla_context.clone();
self.threads.insert(
node.id(),
tokio::spawn(async move { Gemla::process_node(node, gemla_context).await }),
);
} else {
trace!("No node found to process, joining threads");
self.join_threads().await?;
}
}
Ok(())
}
async fn join_threads(&mut self) -> Result<(), Error> {
if !self.threads.is_empty() {
trace!("Joining threads for nodes {:?}", self.threads.keys());
let results = future::join_all(self.threads.values_mut()).await;
// Converting a list of results into a result wrapping the list
let reduced_results: Result<Vec<GeneticNodeWrapper<T>>, Error> =
results.into_iter().flatten().collect();
self.threads.clear();
// We need to retrieve the processed nodes from the resulting list and replace them in the original list
match reduced_results {
Ok(r) => {
self.data
.mutate_async(|d| async move {
// Scope to limit the duration of the read lock
let (_, context) = {
let data_read = d.read().await;
(data_read.1, data_read.2.clone())
}; // Read lock is dropped here
let mut data_write = d.write().await;
if let Some(t) = data_write.0.as_mut() {
let failed_nodes = Gemla::replace_nodes(t, r);
// We receive a list of nodes that were unable to be found in the original tree
if !failed_nodes.is_empty() {
warn!(
"Unable to find {:?} to replace in tree",
failed_nodes.iter().map(|n| n.id())
)
}
// Once the nodes are replaced we need to find nodes that can be merged from the completed children nodes
Gemla::merge_completed_nodes(t, context.clone()).await
} else {
warn!("Unable to replce nodes {:?} in empty tree", r);
Ok(())
}
})
.await??;
}
Err(e) => return Err(e),
}
}
Ok(())
}
/// Walks the tree from `tree` downwards and fills every node that is still in the
/// [`GeneticState::Initialize`] state from its finished children.
///
/// * Two finished children are combined with [`GeneticNode::merge`].
/// * A single finished child is copied into the parent.
/// * Otherwise the search continues in the children.
///
/// The parent keeps its own id in every case. Nodes that are not in the `Initialize` state are
/// left untouched and their subtrees are not visited.
///
/// # Errors
/// Propagates any error returned by [`GeneticNode::merge`].
#[async_recursion]
async fn merge_completed_nodes<'a>(
    tree: &'a mut SimulationTree<T>,
    gemla_context: T::Context,
) -> Result<(), Error> {
    if tree.val.state() != GeneticState::Initialize {
        return Ok(());
    }

    match (&mut tree.left, &mut tree.right) {
        // If the current node has been initialized, and has children nodes that are completed,
        // then we need to merge the children nodes together into the parent node
        (Some(l), Some(r))
            if l.val.state() == GeneticState::Finish
                && r.val.state() == GeneticState::Finish =>
        {
            info!("Merging nodes {} and {}", l.val.id(), r.val.id());
            if let (Some(left_node), Some(right_node)) = (l.val.as_ref(), r.val.as_ref()) {
                let merged_node = GeneticNode::merge(
                    left_node,
                    right_node,
                    &tree.val.id(),
                    gemla_context.clone(),
                )
                .await?;
                tree.val = GeneticNodeWrapper::from(*merged_node, tree.val.id());
            }
        }
        (Some(l), Some(r)) => {
            Gemla::merge_completed_nodes(l, gemla_context.clone()).await?;
            Gemla::merge_completed_nodes(r, gemla_context.clone()).await?;
        }
        // If there is only one child node that's completed then we want to copy it to the
        // parent node
        (Some(l), None) if l.val.state() == GeneticState::Finish => {
            trace!("Copying node {}", l.val.id());
            if let Some(left_node) = l.val.as_ref() {
                // The copy has to be stored in the parent; building the wrapper without
                // assigning it would leave the parent empty.
                tree.val = GeneticNodeWrapper::from(left_node.clone(), tree.val.id());
            }
        }
        (Some(l), None) => Gemla::merge_completed_nodes(l, gemla_context.clone()).await?,
        (None, Some(r)) if r.val.state() == GeneticState::Finish => {
            trace!("Copying node {}", r.val.id());
            if let Some(right_node) = r.val.as_ref() {
                tree.val = GeneticNodeWrapper::from(right_node.clone(), tree.val.id());
            }
        }
        (None, Some(r)) => Gemla::merge_completed_nodes(r, gemla_context.clone()).await?,
        // A leaf has nothing to merge from.
        (_, _) => (),
    }

    Ok(())
}
/// Finds the next node that can be handed to a processing task, searching `tree` depth-first with
/// the left branch taking priority.
///
/// Returns a clone of the node, or `None` when `tree` is finished, is already being processed, or
/// contains no node that is ready.
fn get_unprocessed_node(&self, tree: &SimulationTree<T>) -> Option<GeneticNodeWrapper<T>> {
    // A node that has been processed or is present in the thread list ends the search. Checking
    // the thread list is sufficient because the tree is processed from bottom to top.
    let already_handled = tree.val.state() == GeneticState::Finish
        || self.threads.contains_key(&tree.val.id());
    if already_handled {
        return None;
    }

    match (tree.left.as_ref(), tree.right.as_ref()) {
        (Some(left), Some(right)) => {
            let children_finished = left.val.state() == GeneticState::Finish
                && right.val.state() == GeneticState::Finish;
            if children_finished {
                // Both children are done, so the current node is next. It should already have
                // been merged from its children during join_threads.
                Some(tree.val.clone())
            } else {
                self.get_unprocessed_node(left)
                    .or_else(|| self.get_unprocessed_node(right))
            }
        }
        // With a single child the search simply continues in that child.
        (Some(child), None) | (None, Some(child)) => self.get_unprocessed_node(child),
        // Leaves have no dependencies and can be processed right away.
        (None, None) => Some(tree.val.clone()),
    }
}
fn replace_nodes(
tree: &mut SimulationTree<T>,
mut nodes: Vec<GeneticNodeWrapper<T>>,
) -> Vec<GeneticNodeWrapper<T>> {
// Replacing nodes as we recurse through the tree
if let Some(i) = nodes.iter().position(|n| n.id() == tree.val.id()) {
tree.val = nodes.remove(i);
}
match (&mut tree.left, &mut tree.right) {
(Some(l), Some(r)) => Gemla::replace_nodes(r, Gemla::replace_nodes(l, nodes)),
(Some(l), None) => Gemla::replace_nodes(l, nodes),
(None, Some(r)) => Gemla::replace_nodes(r, nodes),
_ => nodes,
}
}
fn increase_height(tree: Option<SimulationTree<T>>, amount: u64) -> Option<SimulationTree<T>> {
if amount == 0 {
tree
} else {
let left_branch_height =
tree.as_ref().map(|t| t.height() as u64).unwrap_or(0) + amount - 1;
Some(Box::new(Tree::new(
GeneticNodeWrapper::new(),
Gemla::increase_height(tree, amount - 1),
// The right branch height has to equal the left branches total height
if left_branch_height > 0 {
Some(Box::new(btree!(GeneticNodeWrapper::new())))
} else {
None
},
)))
}
}
/// Reports whether `tree` has been fully processed, judged by its root node alone.
fn is_completed(tree: &SimulationTree<T>) -> bool {
    // By convention a finished node implies that all of its children are finished as well, so
    // only the root has to be inspected.
    let root_state = tree.val.state();
    root_state == GeneticState::Finish
}
/// Advances `node` by one processing step and returns it, logging which state was handled and
/// how long the step took.
///
/// # Errors
/// Propagates any error raised while the node is being processed.
async fn process_node(
    mut node: GeneticNodeWrapper<T>,
    gemla_context: T::Context,
) -> Result<GeneticNodeWrapper<T>, Error> {
    let started_at = Instant::now();
    // Remember the state we started from; processing moves the node on to another state.
    let state_before = node.state();

    node.process_node(gemla_context.clone()).await?;

    let elapsed = started_at.elapsed();
    info!(
        "{:?} completed in {:?} for {}",
        state_before,
        elapsed,
        node.id()
    );

    let finished = node.state() == GeneticState::Finish;
    if finished {
        info!("Processed node {}", node.id());
    }

    Ok(node)
}
}
#[cfg(test)]
mod tests {
use crate::core::*;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::fs;
use std::path::PathBuf;
use tokio::runtime::Runtime;
use self::genetic_node::GeneticNodeContext;
/// Test helper that removes the file at `path` once it goes out of scope, so that tests do not
/// leave their data files behind.
struct CleanUp {
    path: PathBuf,
}

impl CleanUp {
    /// Creates a guard for the file at `path`; the file does not have to exist yet.
    fn new(path: &Path) -> CleanUp {
        CleanUp {
            path: path.to_path_buf(),
        }
    }

    /// Runs `op` with the guarded path and returns its result.
    pub fn run<F: FnOnce(&Path) -> Result<(), Error>>(&self, op: F) -> Result<(), Error> {
        op(&self.path)
    }
}

impl Drop for CleanUp {
    /// Removes the guarded file if it exists.
    ///
    /// # Panics
    /// Panics when the file cannot be removed, unless the thread is already unwinding from
    /// another panic: panicking a second time would abort the whole test process and hide the
    /// original failure, so in that case the problem is only reported on stderr.
    fn drop(&mut self) {
        if !self.path.exists() {
            return;
        }

        if let Err(error) = fs::remove_file(&self.path) {
            if std::thread::panicking() {
                eprintln!(
                    "Unable to remove file {}: {:?}",
                    self.path.display(),
                    error
                );
            } else {
                panic!("Unable to remove file: {:?}", error);
            }
        }
    }
}
/// Minimal [`genetic_node::GeneticNode`] implementation used by the tests: every simulated
/// generation adds one point to `score`.
#[derive(Deserialize, Serialize, Clone, Debug, PartialEq)]
struct TestState {
    pub score: f64,
    pub max_generations: u64,
}

#[async_trait]
impl genetic_node::GeneticNode for TestState {
    type Context = ();

    /// Adds one point to the score and reports whether the current generation is still below
    /// `max_generations`.
    async fn simulate(
        &mut self,
        context: GeneticNodeContext<Self::Context>,
    ) -> Result<bool, Error> {
        self.score += 1.0;
        let below_generation_limit = context.generation < self.max_generations;
        Ok(below_generation_limit)
    }

    /// Mutation is a no-op for the test state.
    async fn mutate(
        &mut self,
        _context: GeneticNodeContext<Self::Context>,
    ) -> Result<(), Error> {
        Ok(())
    }

    /// Starts with a score of zero and a limit of ten generations.
    async fn initialize(
        _context: GeneticNodeContext<Self::Context>,
    ) -> Result<Box<TestState>, Error> {
        let fresh = TestState {
            score: 0.0,
            max_generations: 10,
        };
        Ok(Box::new(fresh))
    }

    /// Keeps the state with the strictly higher score; on a tie the right state wins.
    async fn merge(
        left: &TestState,
        right: &TestState,
        _id: &Uuid,
        _: Self::Context,
    ) -> Result<Box<TestState>, Error> {
        let winner = if left.score > right.score {
            left
        } else {
            right
        };
        Ok(Box::new(winner.clone()))
    }
}
// Exercises `Gemla::new` against a file that does not exist yet, then against the same file with
// `overwrite` enabled, and finally with `overwrite` disabled.
#[tokio::test]
async fn test_new() -> Result<(), Error> {
let path = PathBuf::from("test_new_non_existing");
// Use `spawn_blocking` to run synchronous code that needs to call async code internally.
tokio::task::spawn_blocking(move || {
let rt = Runtime::new().unwrap(); // Create a new Tokio runtime for the async block.
// `CleanUp` removes the data file again once the closure has run.
CleanUp::new(&path).run(move |p| {
rt.block_on(async {
assert!(!path.exists());
// Testing initial creation
let mut config = GemlaConfig { overwrite: true };
let mut gemla = Gemla::<TestState>::new(&p, config, DataFormat::Json).await?;
// Now we can use `.await` within the spawned blocking task.
gemla.simulate(2).await?;
let data = gemla.data.readonly();
let data_lock = data.read().await;
assert_eq!(data_lock.0.as_ref().unwrap().height(), 2);
// Release the lock and the instance before the same file is opened again.
drop(data_lock);
drop(gemla);
assert!(path.exists());
// Testing overwriting data: the stored tree is discarded, so the height starts over.
let mut gemla = Gemla::<TestState>::new(&p, config, DataFormat::Json).await?;
gemla.simulate(2).await?;
let data = gemla.data.readonly();
let data_lock = data.read().await;
assert_eq!(data_lock.0.as_ref().unwrap().height(), 2);
drop(data_lock);
drop(gemla);
assert!(path.exists());
// Testing not-overwriting data: the stored tree of height 2 is loaded and grown by
// another 2 levels.
config.overwrite = false;
let mut gemla = Gemla::<TestState>::new(&p, config, DataFormat::Json).await?;
gemla.simulate(2).await?;
let data = gemla.data.readonly();
let data_lock = data.read().await;
let tree = data_lock.0.as_ref().unwrap();
assert_eq!(tree.height(), 4);
drop(data_lock);
drop(gemla);
assert!(path.exists());
Ok(())
})
})
})
.await
.unwrap()?; // Wait for the blocking task to complete, then handle the Result.
Ok(())
}
// Runs a five-level simulation from scratch and checks the height and score of the root node.
#[tokio::test]
async fn test_simulate() -> Result<(), Error> {
    let path = PathBuf::from("test_simulate");

    // The closure handed to `CleanUp::run` is synchronous, so it is executed on a blocking
    // thread with a runtime of its own to drive the async code.
    let outcome = tokio::task::spawn_blocking(move || {
        let rt = Runtime::new().unwrap();
        CleanUp::new(&path).run(move |p| {
            rt.block_on(async {
                // Testing initial creation
                let config = GemlaConfig { overwrite: true };
                let mut gemla = Gemla::<TestState>::new(&p, config, DataFormat::Json).await?;

                gemla.simulate(5).await?;

                let shared = gemla.data.readonly();
                let guard = shared.read().await;
                let root = guard.0.as_ref().unwrap();
                assert_eq!(root.height(), 5);
                assert_eq!(root.val.as_ref().unwrap().score, 50.0);

                Ok(())
            })
        })
    })
    .await;

    // A join failure is a test bug; an inner error is the test result.
    outcome.unwrap()?;
    Ok(())
}
#[tokio::test]
async fn test_tree_ref_returns_arc_rwlock() -> Result<(), Error> {
use std::path::PathBuf;
let path = PathBuf::from("test_tree_ref");
// Use `spawn_blocking` to run the synchronous closure that internally awaits async code.
tokio::task::spawn_blocking(move || {
let rt = tokio::runtime::Runtime::new().unwrap();
super::tests::CleanUp::new(&path).run(move |p| {
rt.block_on(async {
let config = GemlaConfig { overwrite: true };
let gemla = Gemla::<TestState>::new(&p, config, DataFormat::Json).await?;
let arc_rwlock = gemla.tree_ref();
// Check that the returned value is an Arc<RwLock<...>>
let data = arc_rwlock.read().await;
assert!(data.0.is_none());
assert_eq!(data.1.overwrite, true);
Ok(())
})
})
})
.await
.unwrap()?; // Wait for the blocking task to complete, then handle the Result.
Ok(())
}
// Builds a parent with two finished children and checks that merging stores the stronger child
// in the parent.
#[tokio::test]
async fn test_merge_completed_nodes_merges_children() -> Result<(), Error> {
    // Two children that start out with different scores
    let weaker = TestState {
        score: 10.0,
        max_generations: 10,
    };
    let stronger = TestState {
        score: 20.0,
        max_generations: 10,
    };

    let mut left_node = GeneticNodeWrapper::from(weaker.clone(), Uuid::new_v4());
    let mut right_node = GeneticNodeWrapper::from(stronger.clone(), Uuid::new_v4());

    // Drive both children through processing until they report the Finish state
    while left_node.state() != GeneticState::Finish {
        left_node.process_node(()).await?;
    }
    while right_node.state() != GeneticState::Finish {
        right_node.process_node(()).await?;
    }

    // The parent is a blank wrapper, which is what the merge step looks for
    let left_branch = Box::new(crate::tree::Tree::new(left_node, None, None));
    let right_branch = Box::new(crate::tree::Tree::new(right_node, None, None));
    let mut bracket = Box::new(crate::tree::Tree::new(
        GeneticNodeWrapper::new(),
        Some(left_branch),
        Some(right_branch),
    ));

    Gemla::<TestState>::merge_completed_nodes(&mut bracket, ()).await?;

    // The right child wins because its score is higher. The expected value is its initial score
    // plus the points gained while it was processed to completion.
    let merged_score = bracket.val.as_ref().unwrap().score;
    assert_eq!(merged_score, 30.0);

    Ok(())
}
}