diff --git a/crates/graph/src/builder.rs b/crates/graph/src/builder.rs new file mode 100644 index 0000000..68b85c9 --- /dev/null +++ b/crates/graph/src/builder.rs @@ -0,0 +1,153 @@ +use crate::node::{AsyncRuleNode, ConstNode, Node, RuleNode}; +use crate::util; +use crate::{ + AsyncRule, Asynchronous, ErasedNode, Graph, Input, InvalidationSignal, NodeGraph, NodeId, Rule, + Synchronicity, Synchronous, +}; +use std::cell::{Cell, RefCell}; +use std::rc::Rc; + +pub struct GraphBuilder { + pub(crate) node_graph: Rc>>, + pub(crate) output: Option>, + pub(crate) output_type: std::marker::PhantomData, +} + +impl GraphBuilder { + pub fn new() -> Self { + Self { + node_graph: Rc::new(RefCell::new(NodeGraph::new())), + output: None, + output_type: std::marker::PhantomData, + } + } +} + +impl GraphBuilder { + pub fn new_async() -> Self { + Self { + node_graph: Rc::new(RefCell::new(NodeGraph::new())), + output: None, + output_type: std::marker::PhantomData, + } + } +} + +impl GraphBuilder { + pub fn set_output>(&mut self, rule: R) { + let input = self.add_rule(rule); + self.output = Some(input); + } + + fn add_node(&mut self, node: impl Node + 'static) -> Input { + let value = Rc::clone(node.value_rc()); + let erased = ErasedNode::new(node); + let idx = self.node_graph.borrow_mut().add_node(erased); + Input { + node_idx: idx, + value, + } + } + + pub fn add_value(&mut self, value: V) -> Input { + return self.add_node(ConstNode::new(value)); + } + + pub fn add_rule(&mut self, rule: R) -> Input + where + R: Rule, + { + return self.add_node(RuleNode::new(rule)); + } + + pub fn add_invalidatable_rule(&mut self, mut f: F) -> Input + where + R: Rule, + F: FnMut(InvalidationSignal) -> R, + { + let node_idx = Rc::new(Cell::new(None)); + let signal = InvalidationSignal { + node_idx: Rc::clone(&node_idx), + graph: Rc::clone(&self.node_graph), + }; + let input = self.add_rule(f(signal)); + node_idx.set(Some(input.node_idx)); + input + } + + pub fn build(self) -> Result, BuildGraphError> { + let output: &Input = match &self.output { + None => return Err(BuildGraphError::NoOutput), + Some(output) => output, + }; + + let graph = self.node_graph.borrow(); + let indices = graph.node_indices().collect::>(); + drop(graph); + let mut edges = vec![]; + for idx in indices { + let node = &mut self.node_graph.borrow_mut()[idx]; + node.visit_inputs(&mut |input_idx| { + edges.push((input_idx, idx)); + }); + } + + for (source, dest) in edges { + self.node_graph.borrow_mut().add_edge(source, dest, ()); + } + + let mut graph = self.node_graph.borrow_mut(); + util::remove_nodes_not_connected_to(&mut *graph, output.node_idx); + drop(graph); + + let sorted = petgraph::algo::toposort(&**self.node_graph.borrow(), None); + if let Err(_cycle) = sorted { + self.node_graph.borrow_mut().clear_edges(); + // TODO: actually build a vec describing the cycle path for debugging + return Err(BuildGraphError::Cyclic(vec![])); + } + + let graph = Graph { + node_graph: self.node_graph, + output: self.output.unwrap(), + output_type: std::marker::PhantomData, + }; + + Ok(graph) + } +} + +impl GraphBuilder { + pub fn set_async_output>(&mut self, rule: R) { + let input = self.add_async_rule(rule); + self.output = Some(input); + } + + pub fn add_async_rule(&mut self, rule: R) -> Input + where + R: AsyncRule, + { + self.add_node(AsyncRuleNode::new(rule)) + } + + pub fn add_invalidatable_async_rule(&mut self, mut f: F) -> Input + where + R: AsyncRule, + F: FnMut(InvalidationSignal) -> R, + { + let node_idx = Rc::new(Cell::new(None)); + let signal = InvalidationSignal { + node_idx: Rc::clone(&node_idx), + graph: Rc::clone(&self.node_graph), + }; + let input = self.add_async_rule(f(signal)); + node_idx.set(Some(input.node_idx)); + input + } +} + +#[derive(Debug)] +pub enum BuildGraphError { + NoOutput, + Cyclic(Vec), +} diff --git a/crates/graph/src/lib.rs b/crates/graph/src/lib.rs index b9e22bc..2e95231 100644 --- a/crates/graph/src/lib.rs +++ b/crates/graph/src/lib.rs @@ -1,18 +1,28 @@ +mod builder; +mod node; +mod synchronicity; mod util; +use builder::{BuildGraphError, GraphBuilder}; +use node::ErasedNode; use petgraph::visit::{IntoEdgeReferences, NodeIndexable}; -use petgraph::{graph::NodeIndex, stable_graph::StableDiGraph, visit::EdgeRef}; -use std::any::Any; +use petgraph::{stable_graph::StableDiGraph, visit::EdgeRef}; use std::cell::{Cell, Ref, RefCell}; use std::collections::HashMap; use std::collections::VecDeque; -use std::future::Future; use std::ops::{Deref, DerefMut}; -use std::pin::Pin; use std::rc::Rc; +use synchronicity::*; // use a struct for this, not a type alias, because generic bounds of type aliases aren't enforced struct NodeGraph(StableDiGraph, (), u32>); +type NodeId = petgraph::stable_graph::NodeIndex; + +impl NodeGraph { + fn new() -> Self { + Self(StableDiGraph::new()) + } +} impl Deref for NodeGraph { type Target = StableDiGraph, (), u32>; @@ -28,200 +38,6 @@ impl DerefMut for NodeGraph { } } -pub trait Synchronicity: 'static { - type UpdateFn; - fn make_update_fn() -> Self::UpdateFn; - - type UpdateResult<'a>; - fn make_update_result<'a>() -> Self::UpdateResult<'a>; -} - -pub struct Synchronous; - -impl Synchronicity for Synchronous { - type UpdateFn = Box) -> ()>; - - fn make_update_fn() -> Self::UpdateFn { - Box::new(|any| { - let x = any.downcast_mut::>>().unwrap(); - x.update(); - }) - } - - type UpdateResult<'a> = (); - - fn make_update_result<'a>() -> Self::UpdateResult<'a> {} -} - -pub struct Asynchronous; - -impl Synchronicity for Asynchronous { - type UpdateFn = - Box Fn(&'a mut Box) -> Pin + 'a>>>; - - fn make_update_fn() -> Self::UpdateFn { - Box::new(|any| Box::pin(Asynchronous::do_async_update::(any))) - } - - type UpdateResult<'a> = Pin + 'a>>; - - fn make_update_result<'a>() -> Self::UpdateResult<'a> { - Box::pin(std::future::ready(())) - } -} - -impl Asynchronous { - async fn do_async_update(any: &mut Box) { - let x = any.downcast_mut::>>().unwrap(); - x.update().await; - } -} - -pub struct GraphBuilder { - node_graph: Rc>>, - output: Option>, - output_type: std::marker::PhantomData, -} - -impl GraphBuilder { - pub fn new() -> Self { - Self { - node_graph: Rc::new(RefCell::new(NodeGraph(StableDiGraph::new()))), - output: None, - output_type: std::marker::PhantomData, - } - } -} - -impl GraphBuilder { - pub fn new_async() -> Self { - Self { - node_graph: Rc::new(RefCell::new(NodeGraph(StableDiGraph::new()))), - output: None, - output_type: std::marker::PhantomData, - } - } -} - -impl GraphBuilder { - pub fn set_output>(&mut self, rule: R) { - let input = self.add_rule(rule); - self.output = Some(input); - } - - fn add_node(&mut self, node: impl Node + 'static) -> Input { - let value = Rc::clone(node.value_rc()); - let erased = ErasedNode::new(node); - let idx = self.node_graph.borrow_mut().add_node(erased); - Input { - node_idx: idx, - value, - } - } - - pub fn add_value(&mut self, value: V) -> Input { - return self.add_node(ConstNode::new(value)); - } - - pub fn add_rule(&mut self, rule: R) -> Input - where - R: Rule, - { - return self.add_node(RuleNode::new(rule)); - } - - pub fn add_invalidatable_rule(&mut self, mut f: F) -> Input - where - R: Rule, - F: FnMut(InvalidationSignal) -> R, - { - let node_idx = Rc::new(Cell::new(None)); - let signal = InvalidationSignal { - node_idx: Rc::clone(&node_idx), - graph: Rc::clone(&self.node_graph), - }; - let input = self.add_rule(f(signal)); - node_idx.set(Some(input.node_idx)); - input - } - - pub fn build(self) -> Result, BuildGraphError> { - let output: &Input = match &self.output { - None => return Err(BuildGraphError::NoOutput), - Some(output) => output, - }; - - let graph = self.node_graph.borrow(); - let indices = graph.node_indices().collect::>(); - drop(graph); - let mut edges = vec![]; - for idx in indices { - let node = &mut self.node_graph.borrow_mut()[idx]; - node.visit_inputs(&mut |input_idx| { - edges.push((input_idx, idx)); - }); - } - - for (source, dest) in edges { - self.node_graph.borrow_mut().add_edge(source, dest, ()); - } - - let mut graph = self.node_graph.borrow_mut(); - util::remove_nodes_not_connected_to(&mut *graph, output.node_idx); - drop(graph); - - let sorted = petgraph::algo::toposort(&**self.node_graph.borrow(), None); - if let Err(_cycle) = sorted { - self.node_graph.borrow_mut().clear_edges(); - // TODO: actually build a vec describing the cycle path for debugging - return Err(BuildGraphError::Cyclic(vec![])); - } - - let graph = Graph { - node_graph: self.node_graph, - output: self.output.unwrap(), - output_type: std::marker::PhantomData, - }; - - Ok(graph) - } -} - -impl GraphBuilder { - pub fn set_async_output>(&mut self, rule: R) { - let input = self.add_async_rule(rule); - self.output = Some(input); - } - - pub fn add_async_rule(&mut self, rule: R) -> Input - where - R: AsyncRule, - { - self.add_node(AsyncRuleNode::new(rule)) - } - - pub fn add_invalidatable_async_rule(&mut self, mut f: F) -> Input - where - R: AsyncRule, - F: FnMut(InvalidationSignal) -> R, - { - let node_idx = Rc::new(Cell::new(None)); - let signal = InvalidationSignal { - node_idx: Rc::clone(&node_idx), - graph: Rc::clone(&self.node_graph), - }; - let input = self.add_async_rule(f(signal)); - node_idx.set(Some(input.node_idx)); - input - } -} - -#[derive(Debug)] -pub enum BuildGraphError { - NoOutput, - Cyclic(Vec>), -} - pub struct Graph { node_graph: Rc>>, output: Input, @@ -287,7 +103,7 @@ impl Graph { } impl Graph { - fn update_node(&mut self, idx: NodeIndex) { + fn update_node(&mut self, idx: NodeId) { let graph = self.node_graph.borrow(); let node = &graph[idx]; if !node.is_valid() { @@ -307,7 +123,7 @@ impl Graph { let node = &mut self.node_graph.borrow_mut()[idx]; // Actually update the node's value. - (node.update)(&mut node.any); + node.update(); } } @@ -318,7 +134,7 @@ impl Graph { } impl Graph { - async fn update_node_async(&mut self, idx: NodeIndex) { + async fn update_node_async(&mut self, idx: NodeId) { // TODO: same note about recursing as above, and consider doing this in parallel let graph = self.node_graph.borrow(); let node = &graph[idx]; @@ -335,7 +151,7 @@ impl Graph { let mut graph = self.node_graph.borrow_mut(); let node = &mut graph[idx]; - (node.update)(&mut node.any).await; + node.update().await; } } @@ -347,7 +163,7 @@ impl Graph { #[derive(Debug)] pub struct Input { - node_idx: NodeIndex, + node_idx: NodeId, value: Rc>>, } @@ -372,7 +188,7 @@ impl Clone for Input { // TODO: there's a lot happening here, make sure this doesn't create a reference cycle pub struct InvalidationSignal { - node_idx: Rc>>>, + node_idx: Rc>>, graph: Rc>>, } @@ -384,10 +200,7 @@ impl InvalidationSignal { } } -fn invalidate_nodes( - graph: &mut NodeGraph, - mut queue: VecDeque>, -) { +fn invalidate_nodes(graph: &mut NodeGraph, mut queue: VecDeque) { while let Some(idx) = queue.pop_front() { let node = &mut graph[idx]; if node.is_valid() { @@ -402,189 +215,6 @@ fn invalidate_nodes( // TODO: i really want Input to be able to implement Deref somehow -pub struct ErasedNode { - any: Box, - is_valid: Box) -> bool>, - invalidate: Box) -> ()>, - visit_inputs: Box, &mut dyn FnMut(NodeIndex) -> ()) -> ()>, - update: Synch::UpdateFn, -} - -impl ErasedNode { - fn new + 'static, V: 'static>(base: N) -> Self { - // i don't love the double boxing, but i'm not sure how else to do this - let thing: Box> = Box::new(base); - let any: Box = Box::new(thing); - Self { - any, - is_valid: Box::new(|any| { - let x = any.downcast_ref::>>().unwrap(); - x.is_valid() - }), - invalidate: Box::new(|any| { - let x = any.downcast_mut::>>().unwrap(); - x.invalidate(); - }), - visit_inputs: Box::new(|any, visitor| { - let x = any.downcast_ref::>>().unwrap(); - x.visit_inputs(visitor); - }), - update: S::make_update_fn::(), - } - } - - fn is_valid(&self) -> bool { - (self.is_valid)(&self.any) - } - fn invalidate(&mut self) { - (self.invalidate)(&mut self.any); - } - fn visit_inputs(&self, f: &mut dyn FnMut(NodeIndex) -> ()) { - (self.visit_inputs)(&self.any, f); - } -} - -trait Node { - fn is_valid(&self) -> bool; - fn invalidate(&mut self); - fn visit_inputs(&self, visitor: &mut dyn FnMut(NodeIndex) -> ()); - fn update(&mut self) -> Synch::UpdateResult<'_>; - fn value_rc(&self) -> &Rc>>; -} - -struct ConstNode { - value: Rc>>, - synchronicity: std::marker::PhantomData, -} - -impl ConstNode { - fn new(value: V) -> Self { - Self { - value: Rc::new(RefCell::new(Some(value))), - synchronicity: std::marker::PhantomData, - } - } -} - -impl Node for ConstNode { - fn is_valid(&self) -> bool { - true - } - - fn invalidate(&mut self) {} - - fn visit_inputs(&self, _visitor: &mut dyn FnMut(NodeIndex) -> ()) {} - - fn update(&mut self) -> S::UpdateResult<'_> { - unreachable!() - } - - fn value_rc(&self) -> &Rc>> { - &self.value - } -} - -struct RuleNode { - rule: R, - value: Rc>>, - valid: bool, - synchronicity: std::marker::PhantomData, -} - -impl RuleNode { - fn new(rule: R) -> Self { - Self { - rule, - value: Rc::new(RefCell::new(None)), - valid: false, - synchronicity: std::marker::PhantomData, - } - } -} - -impl Node for RuleNode { - fn is_valid(&self) -> bool { - self.valid - } - - fn invalidate(&mut self) { - self.valid = false; - } - - fn visit_inputs(&self, visitor: &mut dyn FnMut(NodeIndex) -> ()) { - struct InputIndexVisitor<'a>(&'a mut dyn FnMut(NodeIndex) -> ()); - impl<'a> InputVisitor for InputIndexVisitor<'a> { - fn visit(&mut self, input: &Input) { - self.0(input.node_idx); - } - } - self.rule.visit_inputs(&mut InputIndexVisitor(visitor)); - } - - fn update(&mut self) -> S::UpdateResult<'_> { - let new_value = self.rule.evaluate(); - self.valid = true; - *self.value.borrow_mut() = Some(new_value); - S::make_update_result() - } - - fn value_rc(&self) -> &Rc>> { - &self.value - } -} - -struct AsyncRuleNode { - rule: R, - value: Rc>>, - valid: bool, -} - -impl AsyncRuleNode { - fn new(rule: R) -> Self { - Self { - rule, - value: Rc::new(RefCell::new(None)), - valid: false, - } - } -} - -impl Node for AsyncRuleNode { - fn is_valid(&self) -> bool { - self.valid - } - - fn invalidate(&mut self) { - self.valid = false; - } - - fn visit_inputs(&self, visitor: &mut dyn FnMut(NodeIndex) -> ()) { - struct InputIndexVisitor<'a>(&'a mut dyn FnMut(NodeIndex) -> ()); - impl<'a> InputVisitor for InputIndexVisitor<'a> { - fn visit(&mut self, input: &Input) { - self.0(input.node_idx); - } - } - self.rule.visit_inputs(&mut InputIndexVisitor(visitor)); - } - - fn update(&mut self) -> ::UpdateResult<'_> { - Box::pin(self.do_update()) - } - - fn value_rc(&self) -> &Rc>> { - &self.value - } -} - -impl AsyncRuleNode { - async fn do_update(&mut self) { - let new_value = self.rule.evaluate().await; - self.valid = true; - *self.value.borrow_mut() = Some(new_value); - } -} - pub trait Rule: 'static { type Output; diff --git a/crates/graph/src/node.rs b/crates/graph/src/node.rs new file mode 100644 index 0000000..a4e3da5 --- /dev/null +++ b/crates/graph/src/node.rs @@ -0,0 +1,200 @@ +use crate::synchronicity::{Asynchronous, Synchronicity}; +use crate::{AsyncRule, Input, InputVisitor, NodeId, Rule, Synchronous}; +use std::any::Any; +use std::cell::RefCell; +use std::rc::Rc; + +pub(crate) struct ErasedNode { + any: Box, + is_valid: Box) -> bool>, + invalidate: Box) -> ()>, + visit_inputs: Box, &mut dyn FnMut(NodeId) -> ()) -> ()>, + update: Synch::UpdateFn, +} + +impl ErasedNode { + pub(crate) fn new + 'static, V: 'static>(base: N) -> Self { + // i don't love the double boxing, but i'm not sure how else to do this + let thing: Box> = Box::new(base); + let any: Box = Box::new(thing); + Self { + any, + is_valid: Box::new(|any| { + let x = any.downcast_ref::>>().unwrap(); + x.is_valid() + }), + invalidate: Box::new(|any| { + let x = any.downcast_mut::>>().unwrap(); + x.invalidate(); + }), + visit_inputs: Box::new(|any, visitor| { + let x = any.downcast_ref::>>().unwrap(); + x.visit_inputs(visitor); + }), + update: S::make_update_fn::(), + } + } + + pub(crate) fn is_valid(&self) -> bool { + (self.is_valid)(&self.any) + } + pub(crate) fn invalidate(&mut self) { + (self.invalidate)(&mut self.any); + } + pub(crate) fn visit_inputs(&self, f: &mut dyn FnMut(NodeId) -> ()) { + (self.visit_inputs)(&self.any, f); + } +} + +impl ErasedNode { + pub(crate) fn update(&mut self) { + (self.update)(&mut self.any) + } +} + +impl ErasedNode { + pub(crate) async fn update(&mut self) { + (self.update)(&mut self.any).await + } +} + +pub(crate) trait Node { + fn is_valid(&self) -> bool; + fn invalidate(&mut self); + fn visit_inputs(&self, visitor: &mut dyn FnMut(NodeId) -> ()); + fn update(&mut self) -> Synch::UpdateResult<'_>; + fn value_rc(&self) -> &Rc>>; +} + +pub(crate) struct ConstNode { + value: Rc>>, + synchronicity: std::marker::PhantomData, +} + +impl ConstNode { + pub(crate) fn new(value: V) -> Self { + Self { + value: Rc::new(RefCell::new(Some(value))), + synchronicity: std::marker::PhantomData, + } + } +} + +impl Node for ConstNode { + fn is_valid(&self) -> bool { + true + } + + fn invalidate(&mut self) {} + + fn visit_inputs(&self, _visitor: &mut dyn FnMut(NodeId) -> ()) {} + + fn update(&mut self) -> S::UpdateResult<'_> { + unreachable!() + } + + fn value_rc(&self) -> &Rc>> { + &self.value + } +} + +pub(crate) struct RuleNode { + rule: R, + value: Rc>>, + valid: bool, + synchronicity: std::marker::PhantomData, +} + +impl RuleNode { + pub(crate) fn new(rule: R) -> Self { + Self { + rule, + value: Rc::new(RefCell::new(None)), + valid: false, + synchronicity: std::marker::PhantomData, + } + } +} + +impl Node for RuleNode { + fn is_valid(&self) -> bool { + self.valid + } + + fn invalidate(&mut self) { + self.valid = false; + } + + fn visit_inputs(&self, visitor: &mut dyn FnMut(NodeId) -> ()) { + struct InputIndexVisitor<'a>(&'a mut dyn FnMut(NodeId) -> ()); + impl<'a> InputVisitor for InputIndexVisitor<'a> { + fn visit(&mut self, input: &Input) { + self.0(input.node_idx); + } + } + self.rule.visit_inputs(&mut InputIndexVisitor(visitor)); + } + + fn update(&mut self) -> S::UpdateResult<'_> { + let new_value = self.rule.evaluate(); + self.valid = true; + *self.value.borrow_mut() = Some(new_value); + S::make_update_result() + } + + fn value_rc(&self) -> &Rc>> { + &self.value + } +} + +pub(crate) struct AsyncRuleNode { + rule: R, + value: Rc>>, + valid: bool, +} + +impl AsyncRuleNode { + pub(crate) fn new(rule: R) -> Self { + Self { + rule, + value: Rc::new(RefCell::new(None)), + valid: false, + } + } +} + +impl Node for AsyncRuleNode { + fn is_valid(&self) -> bool { + self.valid + } + + fn invalidate(&mut self) { + self.valid = false; + } + + fn visit_inputs(&self, visitor: &mut dyn FnMut(NodeId) -> ()) { + struct InputIndexVisitor<'a>(&'a mut dyn FnMut(NodeId) -> ()); + impl<'a> InputVisitor for InputIndexVisitor<'a> { + fn visit(&mut self, input: &Input) { + self.0(input.node_idx); + } + } + self.rule.visit_inputs(&mut InputIndexVisitor(visitor)); + } + + fn update(&mut self) -> ::UpdateResult<'_> { + Box::pin(self.do_update()) + } + + fn value_rc(&self) -> &Rc>> { + &self.value + } +} + +impl AsyncRuleNode { + async fn do_update(&mut self) { + let new_value = self.rule.evaluate().await; + self.valid = true; + *self.value.borrow_mut() = Some(new_value); + } +} diff --git a/crates/graph/src/synchronicity.rs b/crates/graph/src/synchronicity.rs new file mode 100644 index 0000000..f5e95e9 --- /dev/null +++ b/crates/graph/src/synchronicity.rs @@ -0,0 +1,53 @@ +use crate::node::Node; +use std::any::Any; +use std::future::Future; +use std::pin::Pin; + +pub trait Synchronicity: 'static { + type UpdateFn; + fn make_update_fn() -> Self::UpdateFn; + + type UpdateResult<'a>; + fn make_update_result<'a>() -> Self::UpdateResult<'a>; +} + +pub struct Synchronous; + +impl Synchronicity for Synchronous { + type UpdateFn = Box) -> ()>; + + fn make_update_fn() -> Self::UpdateFn { + Box::new(|any| { + let x = any.downcast_mut::>>().unwrap(); + x.update(); + }) + } + + type UpdateResult<'a> = (); + + fn make_update_result<'a>() -> Self::UpdateResult<'a> {} +} + +pub struct Asynchronous; + +impl Synchronicity for Asynchronous { + type UpdateFn = + Box Fn(&'a mut Box) -> Pin + 'a>>>; + + fn make_update_fn() -> Self::UpdateFn { + Box::new(|any| Box::pin(Asynchronous::do_async_update::(any))) + } + + type UpdateResult<'a> = Pin + 'a>>; + + fn make_update_result<'a>() -> Self::UpdateResult<'a> { + Box::pin(std::future::ready(())) + } +} + +impl Asynchronous { + async fn do_async_update(any: &mut Box) { + let x = any.downcast_mut::>>().unwrap(); + x.update().await; + } +}