Modify graph

This commit is contained in:
Shadowfacts 2024-10-29 14:18:53 -04:00
parent 67ddf2f254
commit 29838e2113

View File

@ -1,12 +1,9 @@
use petgraph::{ use petgraph::{graph::NodeIndex, stable_graph::StableDiGraph, visit::EdgeRef};
graph::{DiGraph, NodeIndex},
visit::EdgeRef,
};
use std::cell::{Cell, RefCell}; use std::cell::{Cell, RefCell};
use std::rc::Rc; use std::rc::Rc;
use std::{any::Any, collections::VecDeque}; use std::{any::Any, collections::VecDeque};
type NodeGraph = DiGraph<ErasedNode, (), u32>; type NodeGraph = StableDiGraph<ErasedNode, (), u32>;
pub struct Graph<Output> { pub struct Graph<Output> {
// we treat this as a StableGraph, since nodes are never removed // we treat this as a StableGraph, since nodes are never removed
@ -18,14 +15,13 @@ pub struct Graph<Output> {
impl<Output: Clone + 'static> Graph<Output> { impl<Output: Clone + 'static> Graph<Output> {
pub fn new() -> Self { pub fn new() -> Self {
Self { Self {
node_graph: Rc::new(RefCell::new(DiGraph::new())), node_graph: Rc::new(RefCell::new(StableDiGraph::new())),
output: None, output: None,
output_type: std::marker::PhantomData, output_type: std::marker::PhantomData,
} }
} }
pub fn set_output<R: Rule<Output> + 'static>(&mut self, rule: R) { pub fn set_output<R: Rule<Output> + 'static>(&mut self, rule: R) {
assert!(self.output.is_none(), "cannot replace graph output");
let input = self.add_rule(rule); let input = self.add_rule(rule);
self.output = Some(input.node_idx); self.output = Some(input.node_idx);
} }
@ -70,7 +66,7 @@ impl<Output: Clone + 'static> Graph<Output> {
} }
let graph = self.node_graph.borrow(); let graph = self.node_graph.borrow();
let indices = graph.node_indices(); let indices = graph.node_indices().collect::<Vec<_>>();
drop(graph); drop(graph);
let mut edges = vec![]; let mut edges = vec![];
for idx in indices { for idx in indices {
@ -84,6 +80,8 @@ impl<Output: Clone + 'static> Graph<Output> {
self.node_graph.borrow_mut().add_edge(source, dest, ()); self.node_graph.borrow_mut().add_edge(source, dest, ());
} }
// TODO: remove nodes not connected to output
let sorted = petgraph::algo::toposort(&*self.node_graph.borrow(), None); let sorted = petgraph::algo::toposort(&*self.node_graph.borrow(), None);
if let Err(_cycle) = sorted { if let Err(_cycle) = sorted {
// TODO: actually build a vec describing the cycle path for debugging // TODO: actually build a vec describing the cycle path for debugging
@ -148,6 +146,30 @@ impl<Output: Clone + 'static> FrozenGraph<Output> {
let node = &graph[self.output].expect_type::<Output>(); let node = &graph[self.output].expect_type::<Output>();
node.value() node.value()
} }
pub fn modify<F>(&mut self, mut f: F) -> Result<(), GraphFreezeError>
where
F: FnMut(&mut Graph<Output>) -> (),
{
// Clear the edges before modifying so that re-freezing results in a graph with up-to-date edges.
let mut graph = self.node_graph.borrow_mut();
graph.clear_edges();
drop(graph);
let mut graph = Graph {
node_graph: Rc::clone(&self.node_graph),
output: Some(self.output),
output_type: std::marker::PhantomData,
};
f(&mut graph);
match graph.freeze() {
Ok(g) => {
*self = g;
Ok(())
}
Err(e) => Err(e),
}
}
} }
pub struct Input<T> { pub struct Input<T> {
@ -484,4 +506,18 @@ mod tests {
Ok(_) => assert!(false, "shouldn't have frozen graph"), Ok(_) => assert!(false, "shouldn't have frozen graph"),
} }
} }
#[test]
fn modify_graph() {
let mut graph = Graph::new();
graph.set_output(ConstantRule(1));
let mut frozen = graph.freeze().unwrap();
assert_eq!(frozen.evaluate(), 1);
frozen
.modify(|g| {
g.set_output(ConstantRule(2));
})
.expect("modify");
assert_eq!(frozen.evaluate(), 2);
}
} }