Actually make removing dynamic node children work

This commit is contained in:
Shadowfacts 2024-12-31 18:18:18 -05:00
parent 6bb51638cc
commit f44f525c2c
2 changed files with 98 additions and 111 deletions

View File

@ -242,8 +242,12 @@ impl<O: 'static, S: Synchronicity> GraphBuilder<O, S> {
let mut graph = self.node_graph.borrow_mut();
for (source, dest) in edges {
// The graph may not contain the source node in the case of a removed child
// of a dynamic node.
if graph.contains_node(source) {
graph.add_edge(source, dest, ());
}
}
util::remove_nodes_not_connected_to(&mut *graph, output.node_idx.get().unwrap());

View File

@ -260,22 +260,17 @@ impl<O: 'static, S: Synchronicity> Graph<O, S> {
}
}
impl<O: 'static> Graph<O, Synchronous> {
fn update_invalid_nodes(&mut self) {
impl<O: 'static, S: Synchronicity> Graph<O, S> {
fn process_update_step<'a>(
&'a mut self,
current_idx: NodeId,
ctx: NodeUpdateContext<S>,
) -> UpdateStepResult {
let mut graph = self.node_graph.borrow_mut();
let mut i = 0;
while i < self.sorted_nodes.len() {
let idx = self.sorted_nodes[i];
let node = &mut graph[idx];
if !node.is_valid() {
// Update this node
let mut ctx = NodeUpdateContext::new();
node.update(&mut ctx);
let mut nodes_changed = false;
for idx_to_remove in ctx.removed_nodes {
assert!(
idx_to_remove != idx,
idx_to_remove != current_idx,
"cannot remove node curently being evaluated"
);
let (index_to_remove_in_sorted, _) = self
@ -284,10 +279,6 @@ impl<O: 'static> Graph<O, Synchronous> {
.enumerate()
.find(|(_, idx)| **idx == idx_to_remove)
.expect("removed node must have been already added");
assert!(
index_to_remove_in_sorted > i,
"cannot remove already evaluated node"
);
graph.remove_node(idx_to_remove);
self.sorted_nodes.remove(index_to_remove_in_sorted);
nodes_changed = true;
@ -311,7 +302,7 @@ impl<O: 'static> Graph<O, Synchronous> {
// Invalidate any downstream nodes (which we know we haven't visited yet, because
// we're iterating over a topological sort of the graph).
let dependents = graph
.edges_directed(idx, petgraph::Direction::Outgoing)
.edges_directed(current_idx, petgraph::Direction::Outgoing)
.map(|edge| edge.target())
// Need to collect because the edges_directed iterator borrows the graph, and
// we need to mutably borrow to invalidate.
@ -323,6 +314,36 @@ impl<O: 'static> Graph<O, Synchronous> {
}
if nodes_changed {
UpdateStepResult::Restart
} else {
UpdateStepResult::Continue
}
}
}
#[derive(Clone, Copy, PartialEq, Eq)]
enum UpdateStepResult {
Continue,
Restart,
}
impl<O: 'static> Graph<O, Synchronous> {
fn update_invalid_nodes(&mut self) {
let mut graph = self.node_graph.borrow_mut();
let mut i = 0;
while i < self.sorted_nodes.len() {
let idx = self.sorted_nodes[i];
let node = &mut graph[idx];
if !node.is_valid() {
// Update this node
let mut ctx = NodeUpdateContext::new();
node.update(&mut ctx);
drop(graph);
let result = self.process_update_step(idx, ctx);
graph = self.node_graph.borrow_mut();
if result == UpdateStepResult::Restart {
// If we added/removed nodes, the sorted order has changed, so start evaluating
// from the beginning, in case of changes before i.
i = 0;
@ -371,57 +392,11 @@ impl<O: 'static> Graph<O, Asynchronous> {
let mut ctx = NodeUpdateContext::new();
node.update(&mut ctx).await;
let mut nodes_changed = false;
for idx_to_remove in ctx.removed_nodes {
assert!(
idx_to_remove != idx,
"cannot remove node curently being evaluated"
);
let (index_to_remove_in_sorted, _) = self
.sorted_nodes
.iter()
.enumerate()
.find(|(_, idx)| **idx == idx_to_remove)
.expect("removed node must have been already added");
assert!(
index_to_remove_in_sorted > i,
"cannot remove already evaluated node"
);
graph.remove_node(idx_to_remove);
self.sorted_nodes.remove(index_to_remove_in_sorted);
nodes_changed = true;
}
for (added_node, id_cell) in ctx.added_nodes {
let id = graph.add_node(added_node);
id_cell.set(Some(id));
nodes_changed = true;
}
if nodes_changed {
// Update the graph before invalidating downstream nodes.
drop(graph);
self._modify(|_| {})
.expect("modifying graph during evaluation must produce valid graph");
let result = self.process_update_step(idx, ctx);
graph = self.node_graph.borrow_mut();
}
if ctx.invalidate_dependent_nodes {
// Invalidate any downstream nodes (which we know we haven't visited yet, because
// we're iterating over a topological sort of the graph).
let dependents = graph
.edges_directed(idx, petgraph::Direction::Outgoing)
.map(|edge| edge.target())
// Need to collect because the edges_directed iterator borrows the graph, and
// we need to mutably borrow to invalidate.
.collect::<Vec<_>>();
for dependent_idx in dependents {
let dependent = &mut graph[dependent_idx];
dependent.invalidate();
}
}
if nodes_changed {
if result == UpdateStepResult::Restart {
// If we added/removed nodes, the sorted order has changed, so start evaluating
// from the beginning, in case of changes before i.
i = 0;
@ -433,7 +408,7 @@ impl<O: 'static> Graph<O, Asynchronous> {
}
// Consistency check: after updating in the topological sort order, we should be left with
// no invalid nodes
// no invalid nodes.
debug_assert!(self
.sorted_nodes
.iter()
@ -530,6 +505,8 @@ impl<V> Clone for ValueInvalidationSignal<V> {
#[cfg(test)]
mod tests {
use rule::DynamicNodeFactory;
use super::*;
use crate::rule::{
AsyncDynamicRule, AsyncRule, ConstantRule, DynamicInput, DynamicRule, InputVisitable, Rule,
@ -838,10 +815,13 @@ mod tests {
fn dynamic_rule() {
let mut builder = GraphBuilder::new();
let (count, set_count) = builder.add_invalidatable_value(1);
struct CountUpTo(Input<i32>, Vec<Input<i32>>);
struct CountUpTo {
count: Input<i32>,
node_factory: DynamicNodeFactory<i32, i32>,
}
impl InputVisitable for CountUpTo {
fn visit_inputs(&self, visitor: &mut impl InputVisitor) {
visitor.visit(&self.0);
visitor.visit(&self.count);
}
}
impl DynamicRule for CountUpTo {
@ -850,16 +830,17 @@ mod tests {
&mut self,
ctx: &mut impl rule::DynamicRuleContext,
) -> Vec<Input<Self::ChildOutput>> {
let count = *self.0.value();
assert!(count >= self.1.len() as i32);
while (self.1.len() as i32) < count {
self.1
.push(ctx.add_rule(ConstantRule::new(self.1.len() as i32 + 1)));
let count = *self.count.value();
for i in 1..=count {
self.node_factory.add_rule(ctx, i, || ConstantRule::new(i));
}
self.1.clone()
self.node_factory.all_nodes(ctx)
}
}
let all_inputs = builder.add_dynamic_rule(CountUpTo(count, vec![]));
let all_inputs = builder.add_dynamic_rule(CountUpTo {
count,
node_factory: DynamicNodeFactory::new(),
});
struct Sum(DynamicInput<i32>);
impl InputVisitable for Sum {
fn visit_inputs(&self, visitor: &mut impl InputVisitor) {
@ -879,6 +860,8 @@ mod tests {
assert_eq!(*graph.evaluate(), 3);
set_count.set_value(4);
assert_eq!(*graph.evaluate(), 10);
set_count.set_value(2);
assert_eq!(*graph.evaluate(), 3);
println!("{}", graph.as_dot_string());
}