From 1530933464509c4f3ce8577ae4d34cdb472d4b68 Mon Sep 17 00:00:00 2001 From: Shadowfacts Date: Wed, 30 Oct 2024 23:32:45 -0400 Subject: [PATCH] Make the graph generic over whether it's sync/async --- crates/graph/Cargo.toml | 3 + crates/graph/src/lib.rs | 464 +++++++++++++++++++++++++++++++--------- 2 files changed, 367 insertions(+), 100 deletions(-) diff --git a/crates/graph/Cargo.toml b/crates/graph/Cargo.toml index f8194fe..2caaf97 100644 --- a/crates/graph/Cargo.toml +++ b/crates/graph/Cargo.toml @@ -7,3 +7,6 @@ edition = "2021" [dependencies] petgraph = "0.6.5" + +[dev-dependencies] +tokio = { version = "1.41.0", features = ["rt", "macros"] } diff --git a/crates/graph/src/lib.rs b/crates/graph/src/lib.rs index 509a472..42d281c 100644 --- a/crates/graph/src/lib.rs +++ b/crates/graph/src/lib.rs @@ -1,24 +1,117 @@ #![feature(let_chains)] +#![feature(async_closure)] mod util; use petgraph::visit::{IntoEdgeReferences, NodeIndexable}; use petgraph::{graph::NodeIndex, stable_graph::StableDiGraph, visit::EdgeRef}; -use std::cell::{Cell, RefCell}; +use std::any::Any; +use std::cell::{Cell, Ref, RefCell, RefMut}; use std::collections::HashMap; +use std::collections::VecDeque; +use std::future::Future; +use std::ops::{Deref, DerefMut}; +use std::pin::Pin; use std::rc::Rc; -use std::{any::Any, collections::VecDeque}; -type NodeGraph = StableDiGraph; +// TODO: consider using a struct for this, because generic bounds of type aliases aren't enforced +type NodeGraph = StableDiGraph, (), u32>; -pub struct Graph { - // we treat this as a StableGraph, since nodes are never removed - node_graph: Rc>, +pub trait Synchronicity: 'static { + type AnyStorage; + fn make_any_storage(value: T) -> Self::AnyStorage; + fn unbox_any_storage(storage: &Self::AnyStorage) -> impl Deref; + fn unbox_any_storage_mut( + storage: &mut Self::AnyStorage, + ) -> impl DerefMut; + + type UpdateFn; + fn make_update_fn() -> Self::UpdateFn; + + type UpdateResult<'a>; + fn make_update_result<'a>() -> Self::UpdateResult<'a>; +} + +pub enum Synchronous {} + +impl Synchronicity for Synchronous { + type AnyStorage = Box; + + fn make_any_storage(value: T) -> Self::AnyStorage { + Box::new(value) + } + + fn unbox_any_storage(storage: &Self::AnyStorage) -> impl Deref { + storage.downcast_ref().unwrap() + } + + fn unbox_any_storage_mut( + storage: &mut Self::AnyStorage, + ) -> impl DerefMut { + storage.downcast_mut().unwrap() + } + + type UpdateFn = Box) -> ()>; + + fn make_update_fn() -> Self::UpdateFn { + Box::new(|any| { + let x = any.downcast_mut::>>().unwrap(); + x.update(); + }) + } + + type UpdateResult<'a> = (); + + fn make_update_result<'a>() -> Self::UpdateResult<'a> {} +} + +pub enum Asynchronous {} + +impl Synchronicity for Asynchronous { + type AnyStorage = Rc>>; + + fn make_any_storage(value: T) -> Self::AnyStorage { + Rc::new(RefCell::new(Box::new(value))) + } + + fn unbox_any_storage(storage: &Self::AnyStorage) -> impl Deref { + Ref::map(storage.borrow(), |any| any.downcast_ref().unwrap()) + } + + fn unbox_any_storage_mut( + storage: &mut Self::AnyStorage, + ) -> impl DerefMut { + RefMut::map(storage.borrow_mut(), |any| any.downcast_mut().unwrap()) + } + + type UpdateFn = Box>>) -> Pin>>>; + + fn make_update_fn() -> Self::UpdateFn { + Box::new(|any| Box::pin(Asynchronous::do_async_update::(any))) + } + + type UpdateResult<'a> = Pin + 'a>>; + + fn make_update_result<'a>() -> Self::UpdateResult<'a> { + Box::pin(std::future::ready(())) + } +} + +impl Asynchronous { + async fn do_async_update(any: Rc>>) { + let mut any_ = any.borrow_mut(); + let x = any_.downcast_mut::>>().unwrap(); + x.update().await; + } +} + +pub struct Graph { + node_graph: Rc>>, output: Option>, output_type: std::marker::PhantomData, } -impl Graph { +impl Graph { pub fn new() -> Self { Self { node_graph: Rc::new(RefCell::new(StableDiGraph::new())), @@ -26,13 +119,25 @@ impl Graph { output_type: std::marker::PhantomData, } } +} - pub fn set_output + 'static>(&mut self, rule: R) { +impl Graph { + pub fn new_async() -> Self { + Self { + node_graph: Rc::new(RefCell::new(StableDiGraph::new())), + output: None, + output_type: std::marker::PhantomData, + } + } +} + +impl Graph { + pub fn set_output + 'static>(&mut self, rule: R) { let input = self.add_rule(rule); self.output = Some(input.node_idx); } - fn add_node(&mut self, node: impl Node + 'static) -> Input { + fn add_node(&mut self, node: impl Node + 'static) -> Input { let value = node.value_rc(); let erased = ErasedNode::new(node); let idx = self.node_graph.borrow_mut().add_node(erased); @@ -43,7 +148,7 @@ impl Graph { } pub fn add_value(&mut self, value: V) -> Input { - return self.add_node(ConstNode(value.clone())); + return self.add_node(ConstNode::new(value.clone())); } pub fn add_rule + 'static, V: Clone + 'static>(&mut self, rule: R) -> Input { @@ -54,7 +159,7 @@ impl Graph { where R: Rule + 'static, V: Clone + 'static, - F: FnMut(InvalidationSignal) -> R, + F: FnMut(InvalidationSignal) -> R, { let node_idx = Rc::new(Cell::new(None)); let signal = InvalidationSignal { @@ -66,7 +171,7 @@ impl Graph { input } - pub fn freeze(self) -> Result, GraphFreezeError> { + pub fn freeze(self) -> Result, GraphFreezeError> { let output: NodeIndex = match self.output { None => return Err(GraphFreezeError::NoOutput), Some(idx) => idx, @@ -78,7 +183,7 @@ impl Graph { let mut edges = vec![]; for idx in indices { let node = &mut self.node_graph.borrow_mut()[idx]; - (node.visit_inputs)(&mut node.any, &mut |input_idx| { + node.visit_inputs(&mut |input_idx| { edges.push((input_idx, idx)); }); } @@ -108,57 +213,53 @@ impl Graph { } } +impl Graph { + pub fn set_async_output + 'static>(&mut self, rule: R) { + let input = self.add_async_rule(rule); + self.output = Some(input.node_idx); + } + + pub fn add_async_rule + 'static, V: Clone + 'static>( + &mut self, + rule: R, + ) -> Input { + self.add_node(AsyncRuleNode::new(rule)) + } + + pub fn add_invalidatable_async_rule(&mut self, mut f: F) -> Input + where + R: AsyncRule + 'static, + V: Clone + 'static, + F: FnMut(InvalidationSignal) -> R, + { + let node_idx = Rc::new(Cell::new(None)); + let signal = InvalidationSignal { + node_idx: Rc::clone(&node_idx), + graph: Rc::clone(&self.node_graph), + }; + let input = self.add_async_rule(f(signal)); + node_idx.set(Some(input.node_idx)); + input + } +} + #[derive(Debug)] pub enum GraphFreezeError { NoOutput, Cyclic(Vec>), } -pub struct FrozenGraph { - node_graph: Rc>, +pub struct FrozenGraph { + node_graph: Rc>>, output: NodeIndex, output_type: std::marker::PhantomData, } -impl FrozenGraph { - fn update_node(&mut self, idx: NodeIndex) { - let graph = self.node_graph.borrow(); - let node = &graph[idx]; - let is_valid = (node.is_valid)(&node.any); - drop(graph); - if !is_valid { - // collect all the edges into a vec so that we can mutably borrow the graph to update the nodes - let edge_sources = self - .node_graph - .borrow() - .edges_directed(idx, petgraph::Direction::Incoming) - .map(|edge| edge.source()) - .collect::>(); - - // Update the dependencies of this node. - // TODO: iterating/recursing here seems less than efficient - // instead, in evaluate, topo sort the graph and update invalid nodes? - for source in edge_sources { - self.update_node(source); - } - - let node = &mut self.node_graph.borrow_mut()[idx]; - // Actually update the node's value. - (node.update)(&mut node.any); - } - } - +impl FrozenGraph { pub fn is_output_valid(&self) -> bool { let graph = self.node_graph.borrow(); let node = &graph[self.output]; - (node.is_valid)(&node.any) - } - - pub fn evaluate(&mut self) -> Output { - self.update_node(self.output); - let graph = self.node_graph.borrow(); - let node = &graph[self.output].expect_type::(); - node.value_rc().borrow().clone().unwrap() + node.is_valid() } pub fn node_count(&self) -> usize { @@ -167,7 +268,7 @@ impl FrozenGraph { pub fn modify(&mut self, mut f: F) -> Result<(), GraphFreezeError> where - F: FnMut(&mut Graph) -> (), + F: FnMut(&mut Graph) -> (), { // Copy all the current edges so we can check if any change. let graph = self.node_graph.borrow(); @@ -206,12 +307,74 @@ impl FrozenGraph { to_invalidate.push_back(edge.target()); } } - invalidate_nodes(&mut graph, to_invalidate); + invalidate_nodes::(&mut graph, to_invalidate); Ok(()) } } +impl FrozenGraph { + fn update_node(&mut self, idx: NodeIndex) { + let graph = self.node_graph.borrow(); + let node = &graph[idx]; + if !node.is_valid() { + // collect all the edges into a vec so that we can mutably borrow the graph to update the nodes + let edge_sources = graph + .edges_directed(idx, petgraph::Direction::Incoming) + .map(|edge| edge.source()) + .collect::>(); + drop(graph); + + // Update the dependencies of this node. + // TODO: iterating/recursing here seems less than efficient + // instead, in evaluate, topo sort the graph and update invalid nodes? + for source in edge_sources { + self.update_node(source); + } + + let node = &mut self.node_graph.borrow_mut()[idx]; + // Actually update the node's value. + (node.update)(&mut node.any); + } + } + + pub fn evaluate(&mut self) -> Output { + self.update_node(self.output); + let graph = self.node_graph.borrow(); + let node = &graph[self.output].expect_type::(); + node.value_rc().borrow().clone().unwrap() + } +} + +impl FrozenGraph { + async fn update_node_async(&mut self, idx: NodeIndex) { + // TODO: same note about recursing as above, and consider doing this in parallel + let graph = self.node_graph.borrow(); + let node = &graph[idx]; + if !node.is_valid() { + let edge_sources = graph + .edges_directed(idx, petgraph::Direction::Incoming) + .map(|edge| edge.source()) + .collect::>(); + drop(graph); + + for source in edge_sources { + Box::pin(self.update_node_async(source)).await; + } + + let node = &self.node_graph.borrow()[idx]; + (node.update)(Rc::clone(&node.any)).await; + } + } + + pub async fn evaluate_async(&mut self) -> Output { + self.update_node_async(self.output).await; + let graph = self.node_graph.borrow(); + let node = &graph[self.output].expect_type::(); + node.value_rc().borrow().clone().unwrap() + } +} + #[derive(Clone, Debug)] pub struct Input { node_idx: NodeIndex, @@ -219,7 +382,7 @@ pub struct Input { } impl Input { - fn value(&self) -> T { + pub fn value(&self) -> T { self.value .as_ref() .borrow() @@ -229,24 +392,27 @@ impl Input { } // TODO: there's a lot happening here, make sure this doesn't create a reference cycle -pub struct InvalidationSignal { +pub struct InvalidationSignal { node_idx: Rc>>>, - graph: Rc>, + graph: Rc>>, } -impl InvalidationSignal { +impl InvalidationSignal { pub fn invalidate(&self) { let mut queue = VecDeque::new(); queue.push_back(self.node_idx.get().unwrap()); - invalidate_nodes(&mut *self.graph.borrow_mut(), queue); + invalidate_nodes::(&mut *self.graph.borrow_mut(), queue); } } -fn invalidate_nodes(graph: &mut NodeGraph, mut queue: VecDeque>) { +fn invalidate_nodes( + graph: &mut NodeGraph, + mut queue: VecDeque>, +) { while let Some(idx) = queue.pop_front() { let node = &mut graph[idx]; - if (node.is_valid)(&node.any) { - (node.invalidate)(&mut node.any); + if node.is_valid() { + node.invalidate(); let dependents = graph .edges_directed(idx, petgraph::Direction::Outgoing) .map(|edge| edge.target()); @@ -257,62 +423,74 @@ fn invalidate_nodes(graph: &mut NodeGraph, mut queue: VecDeque>) // TODO: i really want Input to be able to implement Deref somehow -struct ErasedNode { - any: Box, - is_valid: Box) -> bool>, - invalidate: Box) -> ()>, - visit_inputs: Box, &mut dyn FnMut(NodeIndex) -> ()) -> ()>, - update: Box) -> ()>, +pub struct ErasedNode { + any: Synch::AnyStorage, + is_valid: Box bool>, + invalidate: Box ()>, + visit_inputs: Box) -> ()) -> ()>, + update: Synch::UpdateFn, } -impl ErasedNode { - fn new + 'static, V: 'static>(base: N) -> Self { +impl ErasedNode { + fn new + 'static, V: 'static>(base: N) -> Self { // i don't love the double boxing, but i'm not sure how else to do this - let thing: Box> = Box::new(base); - let any: Box = Box::new(thing); + let thing: Box> = Box::new(base); Self { - any, + any: S::make_any_storage(thing), is_valid: Box::new(|any| { - let x = any.downcast_ref::>>().unwrap(); + let x = S::unbox_any_storage::>>(any); x.is_valid() }), invalidate: Box::new(|any| { - let x = any.downcast_mut::>>().unwrap(); + let mut x = S::unbox_any_storage_mut::>>(any); x.invalidate(); }), visit_inputs: Box::new(|any, visitor| { - let x = any.downcast_mut::>>().unwrap(); + let mut x = S::unbox_any_storage_mut::>>(any); x.visit_inputs(visitor); }), - update: Box::new(|any| { - let x = any.downcast_mut::>>().unwrap(); - x.update(); - }), + update: S::make_update_fn::(), } } - // TODO: revisit if these are necessary - fn expect_type<'a, V: 'static>(&'a self) -> &'a dyn Node { - let res = self - .any - .downcast_ref::>>() - .expect("matching node type"); - res.as_ref() + fn expect_type<'a, V: 'static>(&'a self) -> impl Deref>> + 'a { + S::unbox_any_storage::>>(&self.any) + } + + fn is_valid(&self) -> bool { + (self.is_valid)(&self.any) + } + fn invalidate(&mut self) { + (self.invalidate)(&mut self.any); + } + fn visit_inputs(&mut self, f: &mut dyn FnMut(NodeIndex) -> ()) { + (self.visit_inputs)(&mut self.any, f); } } -trait Node { +trait Node { fn is_valid(&self) -> bool; fn invalidate(&mut self); fn visit_inputs(&mut self, visitor: &mut dyn FnMut(NodeIndex) -> ()); - fn update(&mut self); - // TODO: are these both necessary? + fn update(&mut self) -> Synch::UpdateResult<'_>; fn value_rc(&self) -> Rc>>; } -struct ConstNode(V); +struct ConstNode { + value: V, + synchronicity: std::marker::PhantomData, +} -impl Node for ConstNode { +impl ConstNode { + fn new(value: V) -> Self { + Self { + value, + synchronicity: std::marker::PhantomData, + } + } +} + +impl Node for ConstNode { fn is_valid(&self) -> bool { true } @@ -321,30 +499,34 @@ impl Node for ConstNode { fn visit_inputs(&mut self, _visitor: &mut dyn FnMut(NodeIndex) -> ()) {} - fn update(&mut self) {} + fn update(&mut self) -> ::UpdateResult<'_> { + unreachable!() + } fn value_rc(&self) -> Rc>> { - Rc::new(RefCell::new(Some(self.0.clone()))) + Rc::new(RefCell::new(Some(self.value.clone()))) } } -struct RuleNode { +struct RuleNode { rule: R, value: Rc>>, valid: bool, + synchronicity: std::marker::PhantomData, } -impl, V> RuleNode { +impl, V, S> RuleNode { fn new(rule: R) -> Self { Self { rule, value: Rc::new(RefCell::new(None)), valid: false, + synchronicity: std::marker::PhantomData, } } } -impl + 'static, V: Clone + 'static> Node for RuleNode { +impl + 'static, V: Clone + 'static, S: Synchronicity> Node for RuleNode { fn is_valid(&self) -> bool { self.valid } @@ -363,10 +545,11 @@ impl + 'static, V: Clone + 'static> Node for RuleNode { self.rule.visit_inputs(&mut InputIndexVisitor(visitor)); } - fn update(&mut self) { - self.valid = true; + fn update(&mut self) -> ::UpdateResult<'_> { let new_value = self.rule.evaluate(); + self.valid = true; *self.value.borrow_mut() = Some(new_value); + S::make_update_result() } fn value_rc(&self) -> Rc>> { @@ -374,12 +557,70 @@ impl + 'static, V: Clone + 'static> Node for RuleNode { } } +struct AsyncRuleNode { + rule: R, + value: Rc>>, + valid: bool, +} + +impl, V> AsyncRuleNode { + fn new(rule: R) -> Self { + Self { + rule, + value: Rc::new(RefCell::new(None)), + valid: false, + } + } +} + +impl + 'static, V: Clone + 'static> Node for AsyncRuleNode { + fn is_valid(&self) -> bool { + self.valid + } + + fn invalidate(&mut self) { + self.valid = false; + } + + fn visit_inputs(&mut self, visitor: &mut dyn FnMut(NodeIndex) -> ()) { + struct InputIndexVisitor<'a>(&'a mut dyn FnMut(NodeIndex) -> ()); + impl<'a> InputVisitor for InputIndexVisitor<'a> { + fn visit(&mut self, input: &mut Input) { + self.0(input.node_idx); + } + } + self.rule.visit_inputs(&mut InputIndexVisitor(visitor)); + } + + fn update(&mut self) -> ::UpdateResult<'_> { + Box::pin(self.do_update()) + } + + fn value_rc(&self) -> Rc>> { + Rc::clone(&self.value) + } +} + +impl, V> AsyncRuleNode { + async fn do_update(&mut self) { + let new_value = self.rule.evaluate().await; + self.valid = true; + *self.value.borrow_mut() = Some(new_value); + } +} + pub trait Rule { fn visit_inputs(&mut self, visitor: &mut impl InputVisitor); fn evaluate(&mut self) -> Output; } +pub trait AsyncRule { + fn visit_inputs(&mut self, visitor: &mut impl InputVisitor); + + async fn evaluate(&mut self) -> Output; +} + pub trait InputVisitor { fn visit(&mut self, input: &mut Input); } @@ -390,7 +631,7 @@ mod tests { #[test] fn erase_node() { - let node = ErasedNode::new(ConstNode(1234 as i32)); + let node = ErasedNode::::new(ConstNode::new(1234 as i32)); let unwrapped = node.expect_type::(); assert_eq!(unwrapped.value_rc().borrow().unwrap(), 1234); } @@ -512,7 +753,7 @@ mod tests { #[test] fn cant_freeze_no_output() { - let graph = Graph::::new(); + let graph = Graph::::new(); match graph.freeze() { Err(GraphFreezeError::NoOutput) => (), Err(e) => assert!(false, "unexpected error {:?}", e), @@ -579,4 +820,27 @@ mod tests { assert!(!frozen.is_output_valid()); assert_eq!(frozen.evaluate(), 2); } + + #[tokio::test] + async fn async_graph() { + let mut graph = Graph::new_async(); + graph.set_output(ConstantRule(42)); + let mut frozen = graph.freeze().unwrap(); + assert_eq!(frozen.evaluate_async().await, 42); + } + + #[tokio::test] + async fn async_rule() { + struct AsyncConst(i32); + impl AsyncRule for AsyncConst { + fn visit_inputs(&mut self, _visitor: &mut impl InputVisitor) {} + async fn evaluate(&mut self) -> i32 { + self.0 + } + } + let mut graph = Graph::new_async(); + graph.set_async_output(AsyncConst(42)); + let mut frozen = graph.freeze().unwrap(); + assert_eq!(frozen.evaluate_async().await, 42); + } }