Invalidate necessary parts of the graph after modification

This commit is contained in:
Shadowfacts 2024-10-30 00:08:27 -04:00
parent 67fb9db461
commit 81cd986f77

View File

@ -1,7 +1,11 @@
#![feature(let_chains)]
mod util; mod util;
use petgraph::visit::{IntoEdgeReferences, NodeIndexable};
use petgraph::{graph::NodeIndex, stable_graph::StableDiGraph, visit::EdgeRef}; use petgraph::{graph::NodeIndex, stable_graph::StableDiGraph, visit::EdgeRef};
use std::cell::{Cell, RefCell}; use std::cell::{Cell, RefCell};
use std::collections::HashMap;
use std::rc::Rc; use std::rc::Rc;
use std::{any::Any, collections::VecDeque}; use std::{any::Any, collections::VecDeque};
@ -144,6 +148,12 @@ impl<Output: Clone + 'static> FrozenGraph<Output> {
} }
} }
pub fn is_output_valid(&self) -> bool {
let graph = self.node_graph.borrow();
let node = &graph[self.output];
(node.is_valid)(&node.any)
}
pub fn evaluate(&mut self) -> Output { pub fn evaluate(&mut self) -> Output {
self.update_node(self.output); self.update_node(self.output);
let graph = self.node_graph.borrow(); let graph = self.node_graph.borrow();
@ -159,6 +169,17 @@ impl<Output: Clone + 'static> FrozenGraph<Output> {
where where
F: FnMut(&mut Graph<Output>) -> (), F: FnMut(&mut Graph<Output>) -> (),
{ {
// Copy all the current edges so we can check if any change.
let graph = self.node_graph.borrow();
let mut old_edges = HashMap::new();
for edge in graph.edge_references() {
old_edges
.entry(graph.to_index(edge.source()))
.or_insert(vec![])
.push(graph.to_index(edge.target()));
}
drop(graph);
// Clear the edges before modifying so that re-freezing results in a graph with up-to-date edges. // Clear the edges before modifying so that re-freezing results in a graph with up-to-date edges.
let mut graph = self.node_graph.borrow_mut(); let mut graph = self.node_graph.borrow_mut();
graph.clear_edges(); graph.clear_edges();
@ -170,17 +191,28 @@ impl<Output: Clone + 'static> FrozenGraph<Output> {
output_type: std::marker::PhantomData, output_type: std::marker::PhantomData,
}; };
f(&mut graph); f(&mut graph);
match graph.freeze() { *self = graph.freeze()?;
Ok(g) => {
*self = g; // Any new inboud edges invalidate their target nodes.
Ok(()) let mut graph = self.node_graph.borrow_mut();
let mut to_invalidate = VecDeque::new();
for edge in graph.edge_references() {
let source = graph.to_index(edge.source());
let target = graph.to_index(edge.target());
if !old_edges
.get(&source)
.map_or(false, |old| !old.contains(&target))
{
to_invalidate.push_back(edge.target());
} }
Err(e) => Err(e),
} }
invalidate_nodes(&mut graph, to_invalidate);
Ok(())
} }
} }
#[derive(Clone)] #[derive(Clone, Debug)]
pub struct Input<T> { pub struct Input<T> {
node_idx: NodeIndex<u32>, node_idx: NodeIndex<u32>,
value: Rc<RefCell<Option<T>>>, value: Rc<RefCell<Option<T>>>,
@ -204,18 +236,21 @@ pub struct InvalidationSignal {
impl InvalidationSignal { impl InvalidationSignal {
pub fn invalidate(&self) { pub fn invalidate(&self) {
let mut graph = self.graph.borrow_mut();
let mut queue = VecDeque::new(); let mut queue = VecDeque::new();
queue.push_back(self.node_idx.get().unwrap()); queue.push_back(self.node_idx.get().unwrap());
while let Some(idx) = queue.pop_front() { invalidate_nodes(&mut *self.graph.borrow_mut(), queue);
let node = &mut graph[idx]; }
if (node.is_valid)(&node.any) { }
(node.invalidate)(&mut node.any);
let dependents = graph fn invalidate_nodes(graph: &mut NodeGraph, mut queue: VecDeque<NodeIndex<u32>>) {
.edges_directed(idx, petgraph::Direction::Outgoing) while let Some(idx) = queue.pop_front() {
.map(|edge| edge.target()); let node = &mut graph[idx];
queue.extend(dependents); if (node.is_valid)(&node.any) {
} (node.invalidate)(&mut node.any);
let dependents = graph
.edges_directed(idx, petgraph::Direction::Outgoing)
.map(|edge| edge.target());
queue.extend(dependents);
} }
} }
} }
@ -375,6 +410,16 @@ mod tests {
assert_eq!(graph.freeze().unwrap().evaluate(), 1234); assert_eq!(graph.freeze().unwrap().evaluate(), 1234);
} }
#[test]
fn test_output_is_valid() {
let mut graph = Graph::new();
graph.set_output(ConstantRule(1));
let mut frozen = graph.freeze().unwrap();
assert!(!frozen.is_output_valid());
frozen.evaluate();
assert!(frozen.is_output_valid());
}
struct Double(Input<i32>); struct Double(Input<i32>);
impl Rule<i32> for Double { impl Rule<i32> for Double {
fn visit_inputs(&mut self, visitor: &mut impl InputVisitor) { fn visit_inputs(&mut self, visitor: &mut impl InputVisitor) {
@ -475,20 +520,21 @@ mod tests {
} }
} }
struct DeferredInput(Rc<RefCell<Option<Input<i32>>>>);
impl Rule<i32> for DeferredInput {
fn visit_inputs(&mut self, visitor: &mut impl InputVisitor) {
let mut borrowed = self.0.borrow_mut();
let input = borrowed.as_mut().unwrap();
visitor.visit(input);
}
fn evaluate(&mut self) -> i32 {
self.0.borrow().as_ref().unwrap().value()
}
}
#[test] #[test]
fn cant_freeze_cycle() { fn cant_freeze_cycle() {
let mut graph = Graph::new(); let mut graph = Graph::new();
struct DeferredInput(Rc<RefCell<Option<Input<i32>>>>);
impl Rule<i32> for DeferredInput {
fn visit_inputs(&mut self, visitor: &mut impl InputVisitor) {
let mut borrowed = self.0.borrow_mut();
let input = borrowed.as_mut().unwrap();
visitor.visit(input);
}
fn evaluate(&mut self) -> i32 {
0
}
}
let a_input = Rc::new(RefCell::new(None)); let a_input = Rc::new(RefCell::new(None));
let a = graph.add_rule(DeferredInput(Rc::clone(&a_input))); let a = graph.add_rule(DeferredInput(Rc::clone(&a_input)));
let b_input = Rc::new(RefCell::new(Some(a))); let b_input = Rc::new(RefCell::new(Some(a)));
@ -516,4 +562,21 @@ mod tests {
assert_eq!(frozen.evaluate(), 2); assert_eq!(frozen.evaluate(), 2);
assert_eq!(frozen.node_count(), 1); assert_eq!(frozen.node_count(), 1);
} }
#[test]
fn modify_with_dependencies() {
let mut graph = Graph::new();
let input = Rc::new(RefCell::new(None));
graph.set_output(DeferredInput(Rc::clone(&input)));
*input.borrow_mut() = Some(graph.add_value(1));
let mut frozen = graph.freeze().unwrap();
assert_eq!(frozen.evaluate(), 1);
frozen
.modify(|g| {
*input.borrow_mut() = Some(g.add_value(2));
})
.expect("modify");
assert!(!frozen.is_output_valid());
assert_eq!(frozen.evaluate(), 2);
}
} }