GEMLA/gemla/src/core/genetic_node.rs
2024-04-04 18:40:17 -07:00

298 lines
8.6 KiB
Rust

//! A trait used to interact with the internal state of nodes within the [`Bracket`]
//!
//! [`Bracket`]: crate::bracket::Bracket
use crate::error::Error;
use anyhow::Context;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use std::fmt::Debug;
use uuid::Uuid;
use async_trait::async_trait;
/// The lifecycle state of a [`GeneticNode`] held by a [`GeneticNodeWrapper`].
///
/// [`GeneticNodeWrapper::process_node`] inspects the current state to pick the next
/// operation and then advances it:
/// `Initialize` -> `Simulate` -> (`Mutate` -> `Simulate`)* -> `Finish`.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone, Copy)]
pub enum GeneticState {
/// The node and its data have not finished initializing
Initialize,
/// The node is currently simulating a round against target data to determine the fitness of the population
Simulate,
/// The node is currently mutating members of its population and breeding new members
Mutate,
/// The node has finished processing for a given number of iterations
Finish,
}
/// Per-call information handed to [`GeneticNode::initialize`], [`GeneticNode::simulate`]
/// and [`GeneticNode::mutate`].
///
/// [`GeneticNodeWrapper::process_node`] builds one of these from the wrapper's own
/// bookkeeping fields plus the value supplied by its caller.
#[derive(Debug, Clone)]
pub struct GeneticNodeContext<S> {
    /// Generation counter of the wrapper at the time the context was built.
    pub generation: u64,
    /// Generation limit configured on the wrapper.
    pub max_generations: u64,
    /// Identifier of the wrapper the context was built for.
    pub id: Uuid,
    /// Caller-supplied value, forwarded unchanged to the node.
    pub gemla_context: S,
}
/// A trait used to interact with the internal state of nodes within the [`Bracket`]
///
/// Implementors supply the domain-specific work. [`GeneticNodeWrapper::process_node`]
/// decides when `initialize`, `simulate` and `mutate` are invoked; `merge` is not
/// invoked by the wrapper in this file.
///
/// [`Bracket`]: crate::bracket::Bracket
#[async_trait]
pub trait GeneticNode : Send {
/// Caller-defined value that is handed to every operation of the node.
type Context;
/// Creates a new, boxed instance of the implementing node type.
///
/// # Examples
/// TODO
async fn initialize(context: GeneticNodeContext<Self::Context>) -> Result<Box<Self>, Error>;
/// Runs one simulation step on the node, updating it in place.
///
/// NOTE(review): per the [`GeneticState::Simulate`] docs this is expected to score the
/// population against target data; the trait itself does not enforce that.
async fn simulate(&mut self, context: GeneticNodeContext<Self::Context>) -> Result<(), Error>;
/// Mutates members in a population and/or crossbreeds them to produce new offspring.
///
/// # Examples
/// TODO
async fn mutate(&mut self, context: GeneticNodeContext<Self::Context>) -> Result<(), Error>;
/// Produces a new, boxed node from two existing nodes, which are only borrowed.
///
/// NOTE(review): `id` is presumably the identifier assigned to the merged node —
/// confirm against the caller.
async fn merge(left: &Self, right: &Self, id: &Uuid, context: Self::Context) -> Result<Box<Self>, Error>;
}
/// Used externally to wrap a node implementing the [`GeneticNode`] trait. Processes state transitions for the given node as
/// well as signal recovery. Transition states are given by [`GeneticState`]
///
/// All fields are private; read access goes through the accessor methods on the wrapper.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub struct GeneticNodeWrapper<T> {
/// The wrapped node data; `None` until the node has been initialized.
node: Option<T>,
/// Which step `process_node` will perform next.
state: GeneticState,
/// Current generation; starts at 1 and is incremented after every mutate step.
generation: u64,
/// Once a simulate step completes with `generation >= max_generations`, the state becomes `Finish`.
max_generations: u64,
/// Identifier of this wrapper, passed to the node inside every `GeneticNodeContext`.
id: Uuid,
}
impl<T> Default for GeneticNodeWrapper<T> {
    /// Builds an empty wrapper: no node data, [`GeneticState::Initialize`], generation 1
    /// with a limit of 1 generation, and a freshly generated v4 UUID as its id.
    fn default() -> Self {
        let id = Uuid::new_v4();

        Self {
            node: None,
            state: GeneticState::Initialize,
            generation: 1,
            max_generations: 1,
            id,
        }
    }
}
impl<T> GeneticNodeWrapper<T>
where
T: GeneticNode + Debug + Send,
T::Context: Send + Sync + Clone + Debug + Serialize + DeserializeOwned + 'static + Default,
{
pub fn new(max_generations: u64) -> Self {
GeneticNodeWrapper::<T> {
max_generations,
..Default::default()
}
}
pub fn from(data: T, max_generations: u64, id: Uuid) -> Self {
GeneticNodeWrapper {
node: Some(data),
state: GeneticState::Simulate,
generation: 1,
max_generations,
id,
}
}
pub fn as_ref(&self) -> Option<&T> {
self.node.as_ref()
}
pub fn id(&self) -> Uuid {
self.id
}
pub fn max_generations(&self) -> u64 {
self.max_generations
}
pub fn generation(&self) -> u64 {
self.generation
}
pub fn state(&self) -> GeneticState {
self.state
}
pub async fn process_node(&mut self, gemla_context: T::Context) -> Result<GeneticState, Error> {
let context = GeneticNodeContext {
generation: self.generation,
max_generations: self.max_generations,
id: self.id,
gemla_context,
};
match (self.state, &mut self.node) {
(GeneticState::Initialize, _) => {
self.node = Some(*T::initialize(context.clone()).await?);
self.state = GeneticState::Simulate;
}
(GeneticState::Simulate, Some(n)) => {
n.simulate(context.clone()).await
.with_context(|| format!("Error simulating node: {:?}", self))?;
self.state = if self.generation >= self.max_generations {
GeneticState::Finish
} else {
GeneticState::Mutate
};
}
(GeneticState::Mutate, Some(n)) => {
n.mutate(context.clone()).await
.with_context(|| format!("Error mutating node: {:?}", self))?;
self.generation += 1;
self.state = GeneticState::Simulate;
}
(GeneticState::Finish, Some(_)) => (),
_ => panic!("Error processing node {:?}", self.node),
}
Ok(self.state)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::Error;
    use anyhow::anyhow;
    use async_trait::async_trait;

    /// Minimal node used by the tests: `score` counts how many times `simulate` has run.
    #[derive(Deserialize, Serialize, Clone, Debug, PartialEq)]
    struct TestState {
        pub score: f64,
    }

    #[async_trait]
    impl GeneticNode for TestState {
        type Context = ();

        async fn simulate(&mut self, _context: GeneticNodeContext<Self::Context>) -> Result<(), Error> {
            self.score += 1.0;
            Ok(())
        }

        async fn mutate(&mut self, _context: GeneticNodeContext<Self::Context>) -> Result<(), Error> {
            Ok(())
        }

        async fn initialize(_context: GeneticNodeContext<Self::Context>) -> Result<Box<TestState>, Error> {
            Ok(Box::new(TestState { score: 0.0 }))
        }

        async fn merge(_l: &TestState, _r: &TestState, _id: &Uuid, _: Self::Context) -> Result<Box<TestState>, Error> {
            Err(Error::Other(anyhow!("Unable to merge")))
        }
    }

    #[test]
    fn test_new() {
        let genetic_node = GeneticNodeWrapper::<TestState>::new(10);

        // The id is random, so it is copied from the value under test.
        let other_genetic_node = GeneticNodeWrapper::<TestState> {
            node: None,
            state: GeneticState::Initialize,
            generation: 1,
            max_generations: 10,
            id: genetic_node.id(),
        };

        assert_eq!(genetic_node, other_genetic_node);
    }

    #[test]
    fn test_from() {
        let val = TestState { score: 0.0 };
        let uuid = Uuid::new_v4();
        let genetic_node = GeneticNodeWrapper::from(val.clone(), 10, uuid);

        let other_genetic_node = GeneticNodeWrapper::<TestState> {
            node: Some(val),
            state: GeneticState::Simulate,
            generation: 1,
            max_generations: 10,
            id: genetic_node.id(),
        };

        assert_eq!(genetic_node, other_genetic_node);
    }

    #[test]
    fn test_as_ref() {
        let val = TestState { score: 3.0 };
        let genetic_node = GeneticNodeWrapper::from(val.clone(), 10, Uuid::new_v4());

        assert_eq!(genetic_node.as_ref(), Some(&val));
    }

    #[test]
    fn test_id() {
        let uuid = Uuid::new_v4();
        let genetic_node = GeneticNodeWrapper::from(TestState { score: 3.0 }, 10, uuid);

        assert_eq!(genetic_node.id(), uuid);
    }

    #[test]
    fn test_max_generations() {
        let genetic_node = GeneticNodeWrapper::from(TestState { score: 3.0 }, 10, Uuid::new_v4());

        assert_eq!(genetic_node.max_generations(), 10);
    }

    #[test]
    fn test_generation() {
        let genetic_node = GeneticNodeWrapper::from(TestState { score: 3.0 }, 10, Uuid::new_v4());

        assert_eq!(genetic_node.generation(), 1);
    }

    #[test]
    fn test_state() {
        let genetic_node = GeneticNodeWrapper::from(TestState { score: 3.0 }, 10, Uuid::new_v4());

        assert_eq!(genetic_node.state(), GeneticState::Simulate);
    }

    #[tokio::test]
    async fn test_process_node() -> Result<(), Error> {
        let mut genetic_node = GeneticNodeWrapper::<TestState>::new(2);
        assert_eq!(genetic_node.state(), GeneticState::Initialize);
        assert!(genetic_node.as_ref().is_none());

        // Initialize creates the node data.
        assert_eq!(genetic_node.process_node(()).await?, GeneticState::Simulate);
        assert_eq!(genetic_node.as_ref(), Some(&TestState { score: 0.0 }));
        assert_eq!(genetic_node.generation(), 1);

        // First simulation; the generation limit is not reached yet.
        assert_eq!(genetic_node.process_node(()).await?, GeneticState::Mutate);
        assert_eq!(genetic_node.as_ref(), Some(&TestState { score: 1.0 }));
        assert_eq!(genetic_node.generation(), 1);

        // Mutating advances the generation.
        assert_eq!(genetic_node.process_node(()).await?, GeneticState::Simulate);
        assert_eq!(genetic_node.generation(), 2);

        // Second simulation reaches the limit.
        assert_eq!(genetic_node.process_node(()).await?, GeneticState::Finish);
        assert_eq!(genetic_node.as_ref(), Some(&TestState { score: 2.0 }));

        // Finish is terminal: nothing changes any more.
        assert_eq!(genetic_node.process_node(()).await?, GeneticState::Finish);
        assert_eq!(genetic_node.as_ref(), Some(&TestState { score: 2.0 }));
        assert_eq!(genetic_node.generation(), 2);

        Ok(())
    }
}