diff --git a/crates/graph/src/builder.rs b/crates/graph/src/builder.rs index 68b85c9..157b80d 100644 --- a/crates/graph/src/builder.rs +++ b/crates/graph/src/builder.rs @@ -11,6 +11,7 @@ pub struct GraphBuilder { pub(crate) node_graph: Rc>>, pub(crate) output: Option>, pub(crate) output_type: std::marker::PhantomData, + pub(crate) is_valid: Rc>, } impl GraphBuilder { @@ -19,6 +20,7 @@ impl GraphBuilder { node_graph: Rc::new(RefCell::new(NodeGraph::new())), output: None, output_type: std::marker::PhantomData, + is_valid: Rc::new(Cell::new(false)), } } } @@ -29,6 +31,7 @@ impl GraphBuilder { node_graph: Rc::new(RefCell::new(NodeGraph::new())), output: None, output_type: std::marker::PhantomData, + is_valid: Rc::new(Cell::new(false)), } } } @@ -69,6 +72,7 @@ impl GraphBuilder { let signal = InvalidationSignal { node_idx: Rc::clone(&node_idx), graph: Rc::clone(&self.node_graph), + graph_is_valid: Rc::clone(&self.is_valid), }; let input = self.add_rule(f(signal)); node_idx.set(Some(input.node_idx)); @@ -100,17 +104,18 @@ impl GraphBuilder { util::remove_nodes_not_connected_to(&mut *graph, output.node_idx); drop(graph); - let sorted = petgraph::algo::toposort(&**self.node_graph.borrow(), None); - if let Err(_cycle) = sorted { - self.node_graph.borrow_mut().clear_edges(); - // TODO: actually build a vec describing the cycle path for debugging - return Err(BuildGraphError::Cyclic(vec![])); - } + let sorted_nodes = + petgraph::algo::toposort(&**self.node_graph.borrow(), None).map_err(|_| { + // TODO: actually build a vec describing the cycle path for debugging + BuildGraphError::Cyclic(vec![]) + })?; let graph = Graph { node_graph: self.node_graph, output: self.output.unwrap(), output_type: std::marker::PhantomData, + sorted_nodes, + is_valid: self.is_valid, }; Ok(graph) @@ -139,6 +144,7 @@ impl GraphBuilder { let signal = InvalidationSignal { node_idx: Rc::clone(&node_idx), graph: Rc::clone(&self.node_graph), + graph_is_valid: Rc::clone(&self.is_valid), }; let input = self.add_async_rule(f(signal)); node_idx.set(Some(input.node_idx)); diff --git a/crates/graph/src/lib.rs b/crates/graph/src/lib.rs index 2e95231..26c68da 100644 --- a/crates/graph/src/lib.rs +++ b/crates/graph/src/lib.rs @@ -42,13 +42,16 @@ pub struct Graph { node_graph: Rc>>, output: Input, output_type: std::marker::PhantomData, + // The topological sort of nodes in the graph. + sorted_nodes: Vec, + is_valid: Rc>, } impl Graph { pub fn is_output_valid(&self) -> bool { let graph = self.node_graph.borrow(); let node = &graph[self.output.node_idx]; - node.is_valid() + self.is_valid.get() && node.is_valid() } pub fn node_count(&self) -> usize { @@ -75,10 +78,14 @@ impl Graph { graph.clear_edges(); drop(graph); + let old_output = self.output.node_idx; + let mut graph = GraphBuilder { + // TODO: is using the same graph as self correct? if the modify fails, is it left in a bad state? node_graph: Rc::clone(&self.node_graph), output: Some(self.output.clone()), output_type: std::marker::PhantomData, + is_valid: Rc::clone(&self.is_valid), }; f(&mut graph); *self = graph.build()?; @@ -96,67 +103,91 @@ impl Graph { to_invalidate.push_back(edge.target()); } } - invalidate_nodes::(&mut graph, to_invalidate); + // Edge case: if the only node in the graph is the output node, and it's replaced in the modify block, + // there are no edges but we still need to invalidate. + if !to_invalidate.is_empty() || self.output.node_idx != old_output { + self.is_valid.set(false); + for idx in to_invalidate { + let node = &mut graph[idx]; + node.invalidate(); + } + } Ok(()) } } impl Graph { - fn update_node(&mut self, idx: NodeId) { - let graph = self.node_graph.borrow(); - let node = &graph[idx]; - if !node.is_valid() { - // collect all the edges into a vec so that we can mutably borrow the graph to update the nodes - let edge_sources = graph - .edges_directed(idx, petgraph::Direction::Incoming) - .map(|edge| edge.source()) - .collect::>(); - drop(graph); + fn update_invalid_nodes(&mut self) { + let mut graph = self.node_graph.borrow_mut(); + for &idx in self.sorted_nodes.iter() { + let node = &mut graph[idx]; + if !node.is_valid() { + // Update this node + node.update(); - // Update the dependencies of this node. - // TODO: iterating/recursing here seems less than efficient - // instead, in evaluate, topo sort the graph and update invalid nodes? - for source in edge_sources { - self.update_node(source); + // Invalidate any downstream nodes (which we know we haven't visited yet, because we're iterating over a topological sort of the graph) + let dependents = graph + .edges_directed(idx, petgraph::Direction::Outgoing) + .map(|edge| edge.target()) + // Need to collect because the edges_directed iterator borrows the graph, and we need to mutably borrow to invalidate + .collect::>(); + for dependent_idx in dependents { + let dependent = &mut graph[dependent_idx]; + dependent.invalidate(); + } } - - let node = &mut self.node_graph.borrow_mut()[idx]; - // Actually update the node's value. - node.update(); } + // Consistency check: after updating in the topological sort order, we should be left with no invalid nodes + debug_assert!(self + .sorted_nodes + .iter() + .all(|&idx| { (&graph[idx]).is_valid() })); + self.is_valid.set(true); } pub fn evaluate(&mut self) -> impl Deref + '_ { - self.update_node(self.output.node_idx); + if !self.is_valid.get() { + self.update_invalid_nodes(); + } self.output.value() } } impl Graph { - async fn update_node_async(&mut self, idx: NodeId) { - // TODO: same note about recursing as above, and consider doing this in parallel - let graph = self.node_graph.borrow(); - let node = &graph[idx]; - if !node.is_valid() { - let edge_sources = graph - .edges_directed(idx, petgraph::Direction::Incoming) - .map(|edge| edge.source()) - .collect::>(); - drop(graph); - - for source in edge_sources { - Box::pin(self.update_node_async(source)).await; - } - - let mut graph = self.node_graph.borrow_mut(); + async fn update_invalid_nodes(&mut self) { + // TODO: consider whether this can be done in parallel to any degree. + let mut graph = self.node_graph.borrow_mut(); + for &idx in self.sorted_nodes.iter() { let node = &mut graph[idx]; - node.update().await; + if !node.is_valid() { + // Update this node + node.update().await; + + // Invalidate any downstream nodes (which we know we haven't visited yet, because we're iterating over a topological sort of the graph) + let dependents = graph + .edges_directed(idx, petgraph::Direction::Outgoing) + .map(|edge| edge.target()) + // Need to collect because the edges_directed iterator borrows the graph, and we need to mutably borrow to invalidate + .collect::>(); + for dependent_idx in dependents { + let dependent = &mut graph[dependent_idx]; + dependent.invalidate(); + } + } } + // Consistency check: after updating in the topological sort order, we should be left with no invalid nodes + debug_assert!(self + .sorted_nodes + .iter() + .all(|&idx| { (&graph[idx]).is_valid() })); + self.is_valid.set(true); } pub async fn evaluate_async(&mut self) -> impl Deref + '_ { - self.update_node_async(self.output.node_idx).await; + if !self.is_valid.get() { + self.update_invalid_nodes().await; + } self.output.value() } } @@ -190,26 +221,15 @@ impl Clone for Input { pub struct InvalidationSignal { node_idx: Rc>>, graph: Rc>>, + graph_is_valid: Rc>, } impl InvalidationSignal { pub fn invalidate(&self) { - let mut queue = VecDeque::new(); - queue.push_back(self.node_idx.get().unwrap()); - invalidate_nodes::(&mut *self.graph.borrow_mut(), queue); - } -} - -fn invalidate_nodes(graph: &mut NodeGraph, mut queue: VecDeque) { - while let Some(idx) = queue.pop_front() { - let node = &mut graph[idx]; - if node.is_valid() { - node.invalidate(); - let dependents = graph - .edges_directed(idx, petgraph::Direction::Outgoing) - .map(|edge| edge.target()); - queue.extend(dependents); - } + self.graph_is_valid.set(false); + let mut graph = self.graph.borrow_mut(); + let node = &mut graph[self.node_idx.get().unwrap()]; + node.invalidate(); } }