Walk topological sort to update nodes
This commit is contained in:
parent
a6e94340ee
commit
de025dc138
@ -11,6 +11,7 @@ pub struct GraphBuilder<Output, Synch: Synchronicity> {
|
||||
pub(crate) node_graph: Rc<RefCell<NodeGraph<Synch>>>,
|
||||
pub(crate) output: Option<Input<Output>>,
|
||||
pub(crate) output_type: std::marker::PhantomData<Output>,
|
||||
pub(crate) is_valid: Rc<Cell<bool>>,
|
||||
}
|
||||
|
||||
impl<O: 'static> GraphBuilder<O, Synchronous> {
|
||||
@ -19,6 +20,7 @@ impl<O: 'static> GraphBuilder<O, Synchronous> {
|
||||
node_graph: Rc::new(RefCell::new(NodeGraph::new())),
|
||||
output: None,
|
||||
output_type: std::marker::PhantomData,
|
||||
is_valid: Rc::new(Cell::new(false)),
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -29,6 +31,7 @@ impl<O: 'static> GraphBuilder<O, Asynchronous> {
|
||||
node_graph: Rc::new(RefCell::new(NodeGraph::new())),
|
||||
output: None,
|
||||
output_type: std::marker::PhantomData,
|
||||
is_valid: Rc::new(Cell::new(false)),
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -69,6 +72,7 @@ impl<O: 'static, S: Synchronicity> GraphBuilder<O, S> {
|
||||
let signal = InvalidationSignal {
|
||||
node_idx: Rc::clone(&node_idx),
|
||||
graph: Rc::clone(&self.node_graph),
|
||||
graph_is_valid: Rc::clone(&self.is_valid),
|
||||
};
|
||||
let input = self.add_rule(f(signal));
|
||||
node_idx.set(Some(input.node_idx));
|
||||
@ -100,17 +104,18 @@ impl<O: 'static, S: Synchronicity> GraphBuilder<O, S> {
|
||||
util::remove_nodes_not_connected_to(&mut *graph, output.node_idx);
|
||||
drop(graph);
|
||||
|
||||
let sorted = petgraph::algo::toposort(&**self.node_graph.borrow(), None);
|
||||
if let Err(_cycle) = sorted {
|
||||
self.node_graph.borrow_mut().clear_edges();
|
||||
// TODO: actually build a vec describing the cycle path for debugging
|
||||
return Err(BuildGraphError::Cyclic(vec![]));
|
||||
}
|
||||
let sorted_nodes =
|
||||
petgraph::algo::toposort(&**self.node_graph.borrow(), None).map_err(|_| {
|
||||
// TODO: actually build a vec describing the cycle path for debugging
|
||||
BuildGraphError::Cyclic(vec![])
|
||||
})?;
|
||||
|
||||
let graph = Graph {
|
||||
node_graph: self.node_graph,
|
||||
output: self.output.unwrap(),
|
||||
output_type: std::marker::PhantomData,
|
||||
sorted_nodes,
|
||||
is_valid: self.is_valid,
|
||||
};
|
||||
|
||||
Ok(graph)
|
||||
@ -139,6 +144,7 @@ impl<O: 'static> GraphBuilder<O, Asynchronous> {
|
||||
let signal = InvalidationSignal {
|
||||
node_idx: Rc::clone(&node_idx),
|
||||
graph: Rc::clone(&self.node_graph),
|
||||
graph_is_valid: Rc::clone(&self.is_valid),
|
||||
};
|
||||
let input = self.add_async_rule(f(signal));
|
||||
node_idx.set(Some(input.node_idx));
|
||||
|
@ -42,13 +42,16 @@ pub struct Graph<Output, Synch: Synchronicity> {
|
||||
node_graph: Rc<RefCell<NodeGraph<Synch>>>,
|
||||
output: Input<Output>,
|
||||
output_type: std::marker::PhantomData<Output>,
|
||||
// The topological sort of nodes in the graph.
|
||||
sorted_nodes: Vec<NodeId>,
|
||||
is_valid: Rc<Cell<bool>>,
|
||||
}
|
||||
|
||||
impl<O: 'static, S: Synchronicity> Graph<O, S> {
|
||||
pub fn is_output_valid(&self) -> bool {
|
||||
let graph = self.node_graph.borrow();
|
||||
let node = &graph[self.output.node_idx];
|
||||
node.is_valid()
|
||||
self.is_valid.get() && node.is_valid()
|
||||
}
|
||||
|
||||
pub fn node_count(&self) -> usize {
|
||||
@ -75,10 +78,14 @@ impl<O: 'static, S: Synchronicity> Graph<O, S> {
|
||||
graph.clear_edges();
|
||||
drop(graph);
|
||||
|
||||
let old_output = self.output.node_idx;
|
||||
|
||||
let mut graph = GraphBuilder {
|
||||
// TODO: is using the same graph as self correct? if the modify fails, is it left in a bad state?
|
||||
node_graph: Rc::clone(&self.node_graph),
|
||||
output: Some(self.output.clone()),
|
||||
output_type: std::marker::PhantomData,
|
||||
is_valid: Rc::clone(&self.is_valid),
|
||||
};
|
||||
f(&mut graph);
|
||||
*self = graph.build()?;
|
||||
@ -96,67 +103,91 @@ impl<O: 'static, S: Synchronicity> Graph<O, S> {
|
||||
to_invalidate.push_back(edge.target());
|
||||
}
|
||||
}
|
||||
invalidate_nodes::<S>(&mut graph, to_invalidate);
|
||||
// Edge case: if the only node in the graph is the output node, and it's replaced in the modify block,
|
||||
// there are no edges but we still need to invalidate.
|
||||
if !to_invalidate.is_empty() || self.output.node_idx != old_output {
|
||||
self.is_valid.set(false);
|
||||
for idx in to_invalidate {
|
||||
let node = &mut graph[idx];
|
||||
node.invalidate();
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl<O: 'static> Graph<O, Synchronous> {
|
||||
fn update_node(&mut self, idx: NodeId) {
|
||||
let graph = self.node_graph.borrow();
|
||||
let node = &graph[idx];
|
||||
if !node.is_valid() {
|
||||
// collect all the edges into a vec so that we can mutably borrow the graph to update the nodes
|
||||
let edge_sources = graph
|
||||
.edges_directed(idx, petgraph::Direction::Incoming)
|
||||
.map(|edge| edge.source())
|
||||
.collect::<Vec<_>>();
|
||||
drop(graph);
|
||||
fn update_invalid_nodes(&mut self) {
|
||||
let mut graph = self.node_graph.borrow_mut();
|
||||
for &idx in self.sorted_nodes.iter() {
|
||||
let node = &mut graph[idx];
|
||||
if !node.is_valid() {
|
||||
// Update this node
|
||||
node.update();
|
||||
|
||||
// Update the dependencies of this node.
|
||||
// TODO: iterating/recursing here seems less than efficient
|
||||
// instead, in evaluate, topo sort the graph and update invalid nodes?
|
||||
for source in edge_sources {
|
||||
self.update_node(source);
|
||||
// Invalidate any downstream nodes (which we know we haven't visited yet, because we're iterating over a topological sort of the graph)
|
||||
let dependents = graph
|
||||
.edges_directed(idx, petgraph::Direction::Outgoing)
|
||||
.map(|edge| edge.target())
|
||||
// Need to collect because the edges_directed iterator borrows the graph, and we need to mutably borrow to invalidate
|
||||
.collect::<Vec<_>>();
|
||||
for dependent_idx in dependents {
|
||||
let dependent = &mut graph[dependent_idx];
|
||||
dependent.invalidate();
|
||||
}
|
||||
}
|
||||
|
||||
let node = &mut self.node_graph.borrow_mut()[idx];
|
||||
// Actually update the node's value.
|
||||
node.update();
|
||||
}
|
||||
// Consistency check: after updating in the topological sort order, we should be left with no invalid nodes
|
||||
debug_assert!(self
|
||||
.sorted_nodes
|
||||
.iter()
|
||||
.all(|&idx| { (&graph[idx]).is_valid() }));
|
||||
self.is_valid.set(true);
|
||||
}
|
||||
|
||||
pub fn evaluate(&mut self) -> impl Deref<Target = O> + '_ {
|
||||
self.update_node(self.output.node_idx);
|
||||
if !self.is_valid.get() {
|
||||
self.update_invalid_nodes();
|
||||
}
|
||||
self.output.value()
|
||||
}
|
||||
}
|
||||
|
||||
impl<O: 'static> Graph<O, Asynchronous> {
|
||||
async fn update_node_async(&mut self, idx: NodeId) {
|
||||
// TODO: same note about recursing as above, and consider doing this in parallel
|
||||
let graph = self.node_graph.borrow();
|
||||
let node = &graph[idx];
|
||||
if !node.is_valid() {
|
||||
let edge_sources = graph
|
||||
.edges_directed(idx, petgraph::Direction::Incoming)
|
||||
.map(|edge| edge.source())
|
||||
.collect::<Vec<_>>();
|
||||
drop(graph);
|
||||
|
||||
for source in edge_sources {
|
||||
Box::pin(self.update_node_async(source)).await;
|
||||
}
|
||||
|
||||
let mut graph = self.node_graph.borrow_mut();
|
||||
async fn update_invalid_nodes(&mut self) {
|
||||
// TODO: consider whether this can be done in parallel to any degree.
|
||||
let mut graph = self.node_graph.borrow_mut();
|
||||
for &idx in self.sorted_nodes.iter() {
|
||||
let node = &mut graph[idx];
|
||||
node.update().await;
|
||||
if !node.is_valid() {
|
||||
// Update this node
|
||||
node.update().await;
|
||||
|
||||
// Invalidate any downstream nodes (which we know we haven't visited yet, because we're iterating over a topological sort of the graph)
|
||||
let dependents = graph
|
||||
.edges_directed(idx, petgraph::Direction::Outgoing)
|
||||
.map(|edge| edge.target())
|
||||
// Need to collect because the edges_directed iterator borrows the graph, and we need to mutably borrow to invalidate
|
||||
.collect::<Vec<_>>();
|
||||
for dependent_idx in dependents {
|
||||
let dependent = &mut graph[dependent_idx];
|
||||
dependent.invalidate();
|
||||
}
|
||||
}
|
||||
}
|
||||
// Consistency check: after updating in the topological sort order, we should be left with no invalid nodes
|
||||
debug_assert!(self
|
||||
.sorted_nodes
|
||||
.iter()
|
||||
.all(|&idx| { (&graph[idx]).is_valid() }));
|
||||
self.is_valid.set(true);
|
||||
}
|
||||
|
||||
pub async fn evaluate_async(&mut self) -> impl Deref<Target = O> + '_ {
|
||||
self.update_node_async(self.output.node_idx).await;
|
||||
if !self.is_valid.get() {
|
||||
self.update_invalid_nodes().await;
|
||||
}
|
||||
self.output.value()
|
||||
}
|
||||
}
|
||||
@ -190,26 +221,15 @@ impl<T> Clone for Input<T> {
|
||||
pub struct InvalidationSignal<Synch: Synchronicity> {
|
||||
node_idx: Rc<Cell<Option<NodeId>>>,
|
||||
graph: Rc<RefCell<NodeGraph<Synch>>>,
|
||||
graph_is_valid: Rc<Cell<bool>>,
|
||||
}
|
||||
|
||||
impl<S: Synchronicity> InvalidationSignal<S> {
|
||||
pub fn invalidate(&self) {
|
||||
let mut queue = VecDeque::new();
|
||||
queue.push_back(self.node_idx.get().unwrap());
|
||||
invalidate_nodes::<S>(&mut *self.graph.borrow_mut(), queue);
|
||||
}
|
||||
}
|
||||
|
||||
fn invalidate_nodes<S: Synchronicity>(graph: &mut NodeGraph<S>, mut queue: VecDeque<NodeId>) {
|
||||
while let Some(idx) = queue.pop_front() {
|
||||
let node = &mut graph[idx];
|
||||
if node.is_valid() {
|
||||
node.invalidate();
|
||||
let dependents = graph
|
||||
.edges_directed(idx, petgraph::Direction::Outgoing)
|
||||
.map(|edge| edge.target());
|
||||
queue.extend(dependents);
|
||||
}
|
||||
self.graph_is_valid.set(false);
|
||||
let mut graph = self.graph.borrow_mut();
|
||||
let node = &mut graph[self.node_idx.get().unwrap()];
|
||||
node.invalidate();
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user