diff --git a/crates/compute_graph/src/builder.rs b/crates/compute_graph/src/builder.rs index c199ea1..fe61194 100644 --- a/crates/compute_graph/src/builder.rs +++ b/crates/compute_graph/src/builder.rs @@ -242,7 +242,11 @@ impl GraphBuilder { let mut graph = self.node_graph.borrow_mut(); for (source, dest) in edges { - graph.add_edge(source, dest, ()); + // The graph may not contain the source node in the case of a removed child + // of a dynamic node. + if graph.contains_node(source) { + graph.add_edge(source, dest, ()); + } } util::remove_nodes_not_connected_to(&mut *graph, output.node_idx.get().unwrap()); diff --git a/crates/compute_graph/src/lib.rs b/crates/compute_graph/src/lib.rs index 12226c4..8b8176c 100644 --- a/crates/compute_graph/src/lib.rs +++ b/crates/compute_graph/src/lib.rs @@ -260,6 +260,73 @@ impl Graph { } } +impl Graph { + fn process_update_step<'a>( + &'a mut self, + current_idx: NodeId, + ctx: NodeUpdateContext, + ) -> UpdateStepResult { + let mut graph = self.node_graph.borrow_mut(); + let mut nodes_changed = false; + for idx_to_remove in ctx.removed_nodes { + assert!( + idx_to_remove != current_idx, + "cannot remove node curently being evaluated" + ); + let (index_to_remove_in_sorted, _) = self + .sorted_nodes + .iter() + .enumerate() + .find(|(_, idx)| **idx == idx_to_remove) + .expect("removed node must have been already added"); + graph.remove_node(idx_to_remove); + self.sorted_nodes.remove(index_to_remove_in_sorted); + nodes_changed = true; + } + + for (added_node, id_cell) in ctx.added_nodes { + let id = graph.add_node(added_node); + id_cell.set(Some(id)); + nodes_changed = true; + } + + if nodes_changed { + // Update the graph before invalidating downstream nodes. + drop(graph); + self._modify(|_| {}) + .expect("modifying graph during evaluation must produce valid graph"); + graph = self.node_graph.borrow_mut(); + } + + if ctx.invalidate_dependent_nodes { + // Invalidate any downstream nodes (which we know we haven't visited yet, because + // we're iterating over a topological sort of the graph). + let dependents = graph + .edges_directed(current_idx, petgraph::Direction::Outgoing) + .map(|edge| edge.target()) + // Need to collect because the edges_directed iterator borrows the graph, and + // we need to mutably borrow to invalidate. + .collect::>(); + for dependent_idx in dependents { + let dependent = &mut graph[dependent_idx]; + dependent.invalidate(); + } + } + + if nodes_changed { + UpdateStepResult::Restart + } else { + UpdateStepResult::Continue + } + } +} + +#[derive(Clone, Copy, PartialEq, Eq)] +enum UpdateStepResult { + Continue, + Restart, +} + impl Graph { fn update_invalid_nodes(&mut self) { let mut graph = self.node_graph.borrow_mut(); @@ -272,57 +339,11 @@ impl Graph { let mut ctx = NodeUpdateContext::new(); node.update(&mut ctx); - let mut nodes_changed = false; - for idx_to_remove in ctx.removed_nodes { - assert!( - idx_to_remove != idx, - "cannot remove node curently being evaluated" - ); - let (index_to_remove_in_sorted, _) = self - .sorted_nodes - .iter() - .enumerate() - .find(|(_, idx)| **idx == idx_to_remove) - .expect("removed node must have been already added"); - assert!( - index_to_remove_in_sorted > i, - "cannot remove already evaluated node" - ); - graph.remove_node(idx_to_remove); - self.sorted_nodes.remove(index_to_remove_in_sorted); - nodes_changed = true; - } + drop(graph); + let result = self.process_update_step(idx, ctx); + graph = self.node_graph.borrow_mut(); - for (added_node, id_cell) in ctx.added_nodes { - let id = graph.add_node(added_node); - id_cell.set(Some(id)); - nodes_changed = true; - } - - if nodes_changed { - // Update the graph before invalidating downstream nodes. - drop(graph); - self._modify(|_| {}) - .expect("modifying graph during evaluation must produce valid graph"); - graph = self.node_graph.borrow_mut(); - } - - if ctx.invalidate_dependent_nodes { - // Invalidate any downstream nodes (which we know we haven't visited yet, because - // we're iterating over a topological sort of the graph). - let dependents = graph - .edges_directed(idx, petgraph::Direction::Outgoing) - .map(|edge| edge.target()) - // Need to collect because the edges_directed iterator borrows the graph, and - // we need to mutably borrow to invalidate. - .collect::>(); - for dependent_idx in dependents { - let dependent = &mut graph[dependent_idx]; - dependent.invalidate(); - } - } - - if nodes_changed { + if result == UpdateStepResult::Restart { // If we added/removed nodes, the sorted order has changed, so start evaluating // from the beginning, in case of changes before i. i = 0; @@ -371,57 +392,11 @@ impl Graph { let mut ctx = NodeUpdateContext::new(); node.update(&mut ctx).await; - let mut nodes_changed = false; - for idx_to_remove in ctx.removed_nodes { - assert!( - idx_to_remove != idx, - "cannot remove node curently being evaluated" - ); - let (index_to_remove_in_sorted, _) = self - .sorted_nodes - .iter() - .enumerate() - .find(|(_, idx)| **idx == idx_to_remove) - .expect("removed node must have been already added"); - assert!( - index_to_remove_in_sorted > i, - "cannot remove already evaluated node" - ); - graph.remove_node(idx_to_remove); - self.sorted_nodes.remove(index_to_remove_in_sorted); - nodes_changed = true; - } + drop(graph); + let result = self.process_update_step(idx, ctx); + graph = self.node_graph.borrow_mut(); - for (added_node, id_cell) in ctx.added_nodes { - let id = graph.add_node(added_node); - id_cell.set(Some(id)); - nodes_changed = true; - } - - if nodes_changed { - // Update the graph before invalidating downstream nodes. - drop(graph); - self._modify(|_| {}) - .expect("modifying graph during evaluation must produce valid graph"); - graph = self.node_graph.borrow_mut(); - } - - if ctx.invalidate_dependent_nodes { - // Invalidate any downstream nodes (which we know we haven't visited yet, because - // we're iterating over a topological sort of the graph). - let dependents = graph - .edges_directed(idx, petgraph::Direction::Outgoing) - .map(|edge| edge.target()) - // Need to collect because the edges_directed iterator borrows the graph, and - // we need to mutably borrow to invalidate. - .collect::>(); - for dependent_idx in dependents { - let dependent = &mut graph[dependent_idx]; - dependent.invalidate(); - } - } - - if nodes_changed { + if result == UpdateStepResult::Restart { // If we added/removed nodes, the sorted order has changed, so start evaluating // from the beginning, in case of changes before i. i = 0; @@ -433,7 +408,7 @@ impl Graph { } // Consistency check: after updating in the topological sort order, we should be left with - // no invalid nodes + // no invalid nodes. debug_assert!(self .sorted_nodes .iter() @@ -530,6 +505,8 @@ impl Clone for ValueInvalidationSignal { #[cfg(test)] mod tests { + use rule::DynamicNodeFactory; + use super::*; use crate::rule::{ AsyncDynamicRule, AsyncRule, ConstantRule, DynamicInput, DynamicRule, InputVisitable, Rule, @@ -838,10 +815,13 @@ mod tests { fn dynamic_rule() { let mut builder = GraphBuilder::new(); let (count, set_count) = builder.add_invalidatable_value(1); - struct CountUpTo(Input, Vec>); + struct CountUpTo { + count: Input, + node_factory: DynamicNodeFactory, + } impl InputVisitable for CountUpTo { fn visit_inputs(&self, visitor: &mut impl InputVisitor) { - visitor.visit(&self.0); + visitor.visit(&self.count); } } impl DynamicRule for CountUpTo { @@ -850,16 +830,17 @@ mod tests { &mut self, ctx: &mut impl rule::DynamicRuleContext, ) -> Vec> { - let count = *self.0.value(); - assert!(count >= self.1.len() as i32); - while (self.1.len() as i32) < count { - self.1 - .push(ctx.add_rule(ConstantRule::new(self.1.len() as i32 + 1))); + let count = *self.count.value(); + for i in 1..=count { + self.node_factory.add_rule(ctx, i, || ConstantRule::new(i)); } - self.1.clone() + self.node_factory.all_nodes(ctx) } } - let all_inputs = builder.add_dynamic_rule(CountUpTo(count, vec![])); + let all_inputs = builder.add_dynamic_rule(CountUpTo { + count, + node_factory: DynamicNodeFactory::new(), + }); struct Sum(DynamicInput); impl InputVisitable for Sum { fn visit_inputs(&self, visitor: &mut impl InputVisitor) { @@ -879,6 +860,8 @@ mod tests { assert_eq!(*graph.evaluate(), 3); set_count.set_value(4); assert_eq!(*graph.evaluate(), 10); + set_count.set_value(2); + assert_eq!(*graph.evaluate(), 3); println!("{}", graph.as_dot_string()); }