650 lines
22 KiB
Rust
650 lines
22 KiB
Rust
//! Facilities for using a directed, acyclic graph to perform computation.
|
|
//!
|
|
//! A directed, acyclic graph (DAG) can be used to carry out computations by considering
|
|
//! each node to have a value and each edge to represent a dependency on the value of one
|
|
//! node to compute the value of another node. A node's value can either be constant or be
|
|
//! produced by a rule, which is a piece of code for generating the value of a node given its
|
|
//! dependencies. For example, an arithmetic operation can be implemented like so:
|
|
//!
|
|
//! ```rust
|
|
//! # use compute_graph::{builder::GraphBuilder, rule::{Rule, Input, InputVisitable}};
|
|
//! let mut builder = GraphBuilder::new();
|
|
//! let a = builder.add_value(1);
|
|
//! let b = builder.add_value(2);
|
|
//! # #[derive(InputVisitable)]
|
|
//! # struct Add(Input<i32>, Input<i32>);
|
|
//! # impl Rule for Add {
|
|
//! # type Output = i32;
|
|
//! # fn evaluate(&mut self) -> i32 {
|
|
//! # *self.input_0() + *self.input_1()
|
|
//! # }
|
|
//! # }
|
|
//! builder.set_output(Add(a, b));
|
|
//!
|
|
//! let mut graph = builder.build().unwrap();
|
|
//! assert_eq!(*graph.evaluate(), 3);
|
|
//! ```
|
|
//!
|
|
//! Here, `a` and `b` are placeholders representing the values of the two constant nodes in the graph.
|
|
//! The `Add` struct implements the [`Rule`] trait and defines how to combine those two values by addition.
|
|
//! The `Add` rule is implemented as follows:
|
|
//!
|
|
//! ```rust
|
|
//! # use compute_graph::{builder::GraphBuilder, rule::{Rule, Input, InputVisitable}};
|
|
//! #[derive(InputVisitable)]
|
|
//! struct Add(Input<i32>, Input<i32>);
|
|
//!
|
|
//! impl Rule for Add {
|
|
//! type Output = i32;
|
|
//! fn evaluate(&mut self) -> i32 {
|
|
//! *self.input_0() + *self.input_1()
|
|
//! }
|
|
//! }
|
|
//! ```
|
|
|
|
pub mod builder;
|
|
pub mod node;
|
|
pub mod rule;
|
|
pub mod synchronicity;
|
|
mod util;
|
|
|
|
use builder::{BuildGraphError, GraphBuilder};
|
|
use node::{ErasedNode, NodeValue};
|
|
use petgraph::visit::{IntoEdgeReferences, NodeIndexable};
|
|
use petgraph::{stable_graph::StableDiGraph, visit::EdgeRef};
|
|
use rule::{AsyncRule, Input, InputVisitor, Rule};
|
|
use std::cell::{Cell, RefCell};
|
|
use std::collections::HashMap;
|
|
use std::collections::VecDeque;
|
|
use std::ops::{Deref, DerefMut};
|
|
use std::rc::Rc;
|
|
use synchronicity::*;
|
|
|
|
// use a struct for this, not a type alias, because generic bounds of type aliases aren't enforced
|
|
struct NodeGraph<S: Synchronicity>(StableDiGraph<ErasedNode<S>, (), u32>);
|
|
type NodeId = petgraph::stable_graph::NodeIndex<u32>;
|
|
|
|
impl<S: Synchronicity> NodeGraph<S> {
|
|
fn new() -> Self {
|
|
Self(StableDiGraph::new())
|
|
}
|
|
}
|
|
|
|
impl<S: Synchronicity> Deref for NodeGraph<S> {
|
|
type Target = StableDiGraph<ErasedNode<S>, (), u32>;
|
|
|
|
fn deref(&self) -> &Self::Target {
|
|
&self.0
|
|
}
|
|
}
|
|
|
|
impl<S: Synchronicity> DerefMut for NodeGraph<S> {
|
|
fn deref_mut(&mut self) -> &mut Self::Target {
|
|
&mut self.0
|
|
}
|
|
}
|
|
|
|
/// A constructed graph that can be evaluated.
///
/// Use [`GraphBuilder`] to construct a graph.
///
/// The graph is generic over the type of the output node's value and over its [`Synchronicity`],
/// that is, whether it can be evaluated synchronously or asynchronously.
pub struct Graph<Output, Synch: Synchronicity> {
    // Node storage; shared via `Rc` with any builder created by `into_builder`.
    node_graph: Rc<RefCell<NodeGraph<Synch>>>,
    // Handle to the node whose value `evaluate` / `evaluate_async` return.
    output: Input<Output>,
    // Type marker for `Output`.
    output_type: std::marker::PhantomData<Output>,
    // The topological sort of nodes in the graph.
    sorted_nodes: Vec<NodeId>,
    // Whether the output is currently up to date. Shared via `Rc` with builders created by
    // `into_builder`; presumably also shared with invalidation signals — TODO confirm.
    is_valid: Rc<Cell<bool>>,
}

/// A synchronous graph, containing only sync nodes.
///
/// Evaluate it with [`Graph::evaluate`].
pub type SyncGraph<Output> = Graph<Output, Synchronous>;

/// An asynchronous graph, containing a mix of sync and async nodes.
///
/// Evaluate it with [`Graph::evaluate_async`].
pub type AsyncGraph<Output> = Graph<Output, Asynchronous>;
|
|
|
|
impl<O: 'static, S: Synchronicity> Graph<O, S> {
|
|
/// Whether the output value of the graph is currently valid.
|
|
///
|
|
/// The output is considered presumptively invalid if _any_ of the nodes in the graph are invalid,
|
|
/// even if, when evaluated, the invalid node's value is unchanged (in which case, downstream nodes
|
|
/// are not invalidated) and the output may be unchanged.
|
|
pub fn is_output_valid(&self) -> bool {
|
|
self.is_valid.get()
|
|
}
|
|
|
|
/// The number of nodes in the graph.
|
|
pub fn node_count(&self) -> usize {
|
|
self.node_graph.borrow().node_count()
|
|
}
|
|
|
|
/// Modify the graph using the given function.
|
|
///
|
|
/// The function receives as its parameter a [`GraphBuilder`] representing the current graph.
|
|
///
|
|
/// Because building a graph can fail and this method mutates the underlying graph, it takes
|
|
/// ownership of the current graph to prevent the graph being left in an invalid state.
|
|
/// It returns either the new, modified graph or an error.
|
|
pub fn modify<F>(mut self, mut f: F) -> Result<Self, BuildGraphError>
|
|
where
|
|
F: FnMut(&mut GraphBuilder<O, S>) -> (),
|
|
{
|
|
// Copy all the current edges so we can check if any change.
|
|
let graph = self.node_graph.borrow();
|
|
let mut old_edges = HashMap::new();
|
|
for edge in graph.edge_references() {
|
|
old_edges
|
|
.entry(graph.to_index(edge.source()))
|
|
.or_insert(vec![])
|
|
.push(graph.to_index(edge.target()));
|
|
}
|
|
drop(graph);
|
|
|
|
let old_output = self.output.node_idx;
|
|
|
|
// Modify
|
|
let mut builder = self.into_builder();
|
|
f(&mut builder);
|
|
self = builder.build()?;
|
|
|
|
// Any new inboud edges invalidate their target nodes.
|
|
let mut graph = self.node_graph.borrow_mut();
|
|
let mut to_invalidate = VecDeque::new();
|
|
for edge in graph.edge_references() {
|
|
let source = graph.to_index(edge.source());
|
|
let target = graph.to_index(edge.target());
|
|
if !old_edges
|
|
.get(&source)
|
|
.map_or(false, |old| !old.contains(&target))
|
|
{
|
|
to_invalidate.push_back(edge.target());
|
|
}
|
|
}
|
|
// Edge case: if the only node in the graph is the output node, and it's replaced in the modify block,
|
|
// there are no edges but we still need to invalidate.
|
|
if !to_invalidate.is_empty() || self.output.node_idx != old_output {
|
|
self.is_valid.set(false);
|
|
for idx in to_invalidate {
|
|
let node = &mut graph[idx];
|
|
node.invalidate();
|
|
}
|
|
}
|
|
|
|
drop(graph);
|
|
Ok(self)
|
|
}
|
|
|
|
/// Convert this graph back into a builder for further modifications.
|
|
///
|
|
/// Returns a builder with the same output and synchronicity types.
|
|
pub fn into_builder(self) -> GraphBuilder<O, S> {
|
|
// Clear the edges before modifying so that rebuilding results in a graph with up-to-date edges.
|
|
let mut graph = self.node_graph.borrow_mut();
|
|
graph.clear_edges();
|
|
drop(graph);
|
|
|
|
GraphBuilder {
|
|
node_graph: Rc::clone(&self.node_graph),
|
|
output: Some(self.output.clone()),
|
|
output_type: std::marker::PhantomData,
|
|
is_valid: Rc::clone(&self.is_valid),
|
|
}
|
|
}
|
|
}
|
|
|
|
impl<O: 'static> Graph<O, Synchronous> {
|
|
fn update_invalid_nodes(&mut self) {
|
|
let mut graph = self.node_graph.borrow_mut();
|
|
for &idx in self.sorted_nodes.iter() {
|
|
let node = &mut graph[idx];
|
|
if !node.is_valid() {
|
|
// Update this node
|
|
let value_changed = node.update();
|
|
|
|
if value_changed {
|
|
// Invalidate any downstream nodes (which we know we haven't visited yet, because
|
|
// we're iterating over a topological sort of the graph).
|
|
let dependents = graph
|
|
.edges_directed(idx, petgraph::Direction::Outgoing)
|
|
.map(|edge| edge.target())
|
|
// Need to collect because the edges_directed iterator borrows the graph, and
|
|
// we need to mutably borrow to invalidate.
|
|
.collect::<Vec<_>>();
|
|
for dependent_idx in dependents {
|
|
let dependent = &mut graph[dependent_idx];
|
|
dependent.invalidate();
|
|
}
|
|
}
|
|
}
|
|
}
|
|
// Consistency check: after updating in the topological sort order, we should be left with
|
|
// no invalid nodes
|
|
debug_assert!(self
|
|
.sorted_nodes
|
|
.iter()
|
|
.all(|&idx| { (&graph[idx]).is_valid() }));
|
|
self.is_valid.set(true);
|
|
}
|
|
|
|
/// Synchronously evaluate the graph and return a reference to the value of the output node.
|
|
///
|
|
/// If the graph is valid (see [`Graph::is_output_valid`]), this is a constant-time operation.
|
|
/// Otherwise, any invalid nodes and their downstream dependents will be updated, which is an
|
|
/// O(n) operation.
|
|
///
|
|
/// This method is only available on synchronous graphs, which can only contain synchronous nodes.
|
|
pub fn evaluate(&mut self) -> impl Deref<Target = O> + '_ {
|
|
if !self.is_valid.get() {
|
|
self.update_invalid_nodes();
|
|
}
|
|
self.output.value()
|
|
}
|
|
}
|
|
|
|
impl<O: 'static> Graph<O, Asynchronous> {
|
|
async fn update_invalid_nodes(&mut self) {
|
|
// TODO: consider whether this can be done in parallel to any degree.
|
|
let mut graph = self.node_graph.borrow_mut();
|
|
for &idx in self.sorted_nodes.iter() {
|
|
let node = &mut graph[idx];
|
|
if !node.is_valid() {
|
|
// Update this node
|
|
let value_changed = node.update().await;
|
|
|
|
if value_changed {
|
|
// Invalidate any downstream nodes (which we know we haven't visited yet, because
|
|
// we're iterating over a topological sort of the graph).
|
|
let dependents = graph
|
|
.edges_directed(idx, petgraph::Direction::Outgoing)
|
|
.map(|edge| edge.target())
|
|
// Need to collect because the edges_directed iterator borrows the graph, and
|
|
// we need to mutably borrow to invalidate.
|
|
.collect::<Vec<_>>();
|
|
for dependent_idx in dependents {
|
|
let dependent = &mut graph[dependent_idx];
|
|
dependent.invalidate();
|
|
}
|
|
}
|
|
}
|
|
}
|
|
// Consistency check: after updating in the topological sort order, we should be left with
|
|
// no invalid nodes
|
|
debug_assert!(self
|
|
.sorted_nodes
|
|
.iter()
|
|
.all(|&idx| { (&graph[idx]).is_valid() }));
|
|
self.is_valid.set(true);
|
|
}
|
|
|
|
/// Asynchronously evaluate the graph and return a reference to the value of the output node.
|
|
///
|
|
/// If the graph is valid (see [`Graph::is_output_valid`]), this is a constant-time operation.
|
|
/// Otherwise, any invalid nodes and their downstream dependents will be updated, which is an
|
|
/// O(n) operation.
|
|
///
|
|
/// This method is only available on asynchronous graphs, which can contain a mix of asynchronous
|
|
/// and synchronous nodes.
|
|
pub async fn evaluate_async(&mut self) -> impl Deref<Target = O> + '_ {
|
|
if !self.is_valid.get() {
|
|
self.update_invalid_nodes().await;
|
|
}
|
|
self.output.value()
|
|
}
|
|
}
|
|
|
|
/// A handle to a graph node that can be invalidated due to external factors.
///
/// See [`GraphBuilder::add_invalidatable_rule`].
///
/// `InvalidationSignal` implements `Clone`, so the signal can be cloned and used from multiple
/// places; every clone shares the same underlying callback.
// TODO: there's a lot happening here, make sure this doesn't create a reference cycle
#[derive(Clone)]
pub struct InvalidationSignal {
    do_invalidate: Rc<Box<dyn Fn()>>,
}

impl InvalidationSignal {
    /// Tell the graph that the node corresponding to this signal is now invalid.
    ///
    /// Note: Calling this method does not trigger a graph evaluation, it merely marks the corresponding
    /// node as invalid. The graph will not be re-evaluated until [`Graph::evaluate`] or
    /// [`Graph::evaluate_async`] is next called.
    pub fn invalidate(&self) {
        let callback = &self.do_invalidate;
        callback();
    }
}
|
|
|
|
/// A type representing a node with an externally injected value.
|
|
///
|
|
/// See [`GraphBuilder::add_invalidatable_value`].
|
|
pub struct ValueInvalidationSignal<V> {
|
|
input: Input<V>,
|
|
signal: InvalidationSignal,
|
|
}
|
|
|
|
impl<V: NodeValue> ValueInvalidationSignal<V> {
|
|
/// Get a reference to current value for the node corresponding to this signal.
|
|
pub fn value(&self) -> impl Deref<Target = V> + '_ {
|
|
self.input.value()
|
|
}
|
|
|
|
/// Set a new value for the node corresponding to this signal.
|
|
///
|
|
/// Note: Calling this method does not trigger a graph evaluation, it merely sets a new value
|
|
/// for the corresponding node. The graph will not be re-evaluated until [`Graph::evaluate`] or
|
|
/// [`Graph::evaluate_async`] is next called.
|
|
pub fn set_value(&self, value: V) {
|
|
let mut current_value = self.input.value.borrow_mut();
|
|
if !current_value
|
|
.as_ref()
|
|
.expect("invalidatable value node must be initialized with value")
|
|
.node_value_eq(&value)
|
|
{
|
|
*current_value = Some(value);
|
|
self.signal.invalidate();
|
|
}
|
|
}
|
|
}
|
|
|
|
impl<V> Clone for ValueInvalidationSignal<V> {
|
|
fn clone(&self) -> Self {
|
|
Self {
|
|
input: self.input.clone(),
|
|
signal: self.signal.clone(),
|
|
}
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;
    use crate::rule::{ConstantRule, InputVisitable};

    #[test]
    fn rule_output_with_no_inputs() {
        let mut builder = GraphBuilder::new();
        builder.set_output(ConstantRule::new(1234));
        assert_eq!(*builder.build().unwrap().evaluate(), 1234);
    }

    #[test]
    fn test_output_is_valid() {
        let mut builder = GraphBuilder::new();
        builder.set_output(ConstantRule::new(1));
        let mut graph = builder.build().unwrap();
        assert!(!graph.is_output_valid());
        graph.evaluate();
        assert!(graph.is_output_valid());
    }

    /// Test rule that doubles its single input.
    struct Double(Input<i32>);
    impl InputVisitable for Double {
        fn visit_inputs(&self, visitor: &mut impl InputVisitor) {
            visitor.visit(&self.0);
        }
    }
    impl Rule for Double {
        type Output = i32;
        fn evaluate(&mut self) -> i32 {
            *self.0.value() * 2
        }
    }

    #[test]
    fn rule_with_input() {
        let mut builder = GraphBuilder::new();
        let input = builder.add_value(42);
        builder.set_output(Double(input));
        assert_eq!(*builder.build().unwrap().evaluate(), 84);
    }

    #[test]
    fn rule_with_input_rule() {
        let mut builder = GraphBuilder::new();
        let input = builder.add_value(42);
        let doubled = builder.add_rule(Double(input));
        builder.set_output(Double(doubled));
        assert_eq!(*builder.build().unwrap().evaluate(), 168);
    }

    /// Test rule with no inputs whose value increments every time it is evaluated.
    struct Inc(i32);
    impl InputVisitable for Inc {
        fn visit_inputs(&self, _visitor: &mut impl InputVisitor) {}
    }
    impl Rule for Inc {
        type Output = i32;
        fn evaluate(&mut self) -> i32 {
            self.0 += 1;
            self.0
        }
    }

    #[test]
    fn invalidatable_rule() {
        let mut builder = GraphBuilder::new();
        let (input, invalidate) = builder.add_invalidatable_rule(Inc(0));
        builder.set_output(Double(input));
        let mut graph = builder.build().unwrap();
        assert_eq!(*graph.evaluate(), 2);
        invalidate.invalidate();
        assert_eq!(*graph.evaluate(), 4);
        assert_eq!(*graph.evaluate(), 4);
        invalidate.invalidate();
        assert_eq!(*graph.evaluate(), 6);
    }

    /// Test rule that sums its two inputs.
    struct Add(Input<i32>, Input<i32>);
    impl InputVisitable for Add {
        fn visit_inputs(&self, visitor: &mut impl InputVisitor) {
            visitor.visit(&self.0);
            visitor.visit(&self.1);
        }
    }
    impl Rule for Add {
        type Output = i32;
        fn evaluate(&mut self) -> i32 {
            *self.0.value() + *self.1.value()
        }
    }

    #[test]
    fn rule_with_multiple_inputs() {
        let mut builder = GraphBuilder::new();
        let a = builder.add_value(2);
        let b = builder.add_value(3);
        builder.set_output(Add(a, b));
        assert_eq!(*builder.build().unwrap().evaluate(), 5);
    }

    #[test]
    fn rule_with_invalidatable_inputs() {
        let mut builder = GraphBuilder::new();
        let (a, invalidate) = builder.add_invalidatable_rule(Inc(0));
        let b = builder.add_rule(Inc(0));
        builder.set_output(Add(a, b));
        let mut graph = builder.build().unwrap();
        assert_eq!(*graph.evaluate(), 2);
        invalidate.invalidate();
        assert_eq!(*graph.evaluate(), 3);
        assert_eq!(*graph.evaluate(), 3);
    }

    #[test]
    fn cant_freeze_no_output() {
        let builder = GraphBuilder::<i32, Synchronous>::new();
        match builder.build() {
            Err(BuildGraphError::NoOutput) => (),
            Err(e) => panic!("unexpected error {:?}", e),
            Ok(_) => panic!("shouldn't have frozen graph"),
        }
    }

    /// Test rule whose input is supplied after the rule is added, which allows building cycles.
    struct DeferredInput(Rc<RefCell<Option<Input<i32>>>>);
    impl InputVisitable for DeferredInput {
        fn visit_inputs(&self, visitor: &mut impl InputVisitor) {
            let borrowed = self.0.borrow();
            let input = borrowed.as_ref().unwrap();
            visitor.visit(input);
        }
    }
    impl Rule for DeferredInput {
        type Output = i32;
        fn evaluate(&mut self) -> i32 {
            *self.0.borrow().as_ref().unwrap().value()
        }
    }

    #[test]
    fn cant_freeze_cycle() {
        let mut builder = GraphBuilder::new();
        let a_input = Rc::new(RefCell::new(None));
        let a = builder.add_rule(DeferredInput(Rc::clone(&a_input)));
        let b_input = Rc::new(RefCell::new(Some(a)));
        let b = builder.add_rule(DeferredInput(b_input));
        *a_input.borrow_mut() = Some(b.clone());
        builder.set_output(Double(b));
        match builder.build() {
            Err(BuildGraphError::Cycle(_)) => (),
            Err(e) => panic!("unexpected error {:?}", e),
            Ok(_) => panic!("shouldn't have frozen graph"),
        }
    }

    #[test]
    fn modify_graph() {
        let mut builder = GraphBuilder::new();
        builder.set_output(ConstantRule::new(1));
        let mut graph = builder.build().unwrap();
        assert_eq!(*graph.evaluate(), 1);
        graph = graph
            .modify(|g| {
                g.set_output(ConstantRule::new(2));
            })
            .expect("modify");
        assert_eq!(*graph.evaluate(), 2);
        assert_eq!(graph.node_count(), 1);
    }

    #[test]
    fn modify_with_dependencies() {
        let mut builder = GraphBuilder::new();
        let input = Rc::new(RefCell::new(None));
        builder.set_output(DeferredInput(Rc::clone(&input)));
        *input.borrow_mut() = Some(builder.add_value(1));
        let mut graph = builder.build().unwrap();
        assert_eq!(*graph.evaluate(), 1);
        graph = graph
            .modify(|g| {
                *input.borrow_mut() = Some(g.add_value(2));
            })
            .expect("modify");
        assert!(!graph.is_output_valid());
        assert_eq!(*graph.evaluate(), 2);
    }

    #[tokio::test]
    async fn async_graph() {
        let mut builder = GraphBuilder::new_async();
        builder.set_output(ConstantRule::new(42));
        let mut graph = builder.build().unwrap();
        assert_eq!(*graph.evaluate_async().await, 42);
    }

    #[tokio::test]
    async fn async_rule() {
        struct AsyncConst(i32);
        impl InputVisitable for AsyncConst {
            fn visit_inputs(&self, _visitor: &mut impl InputVisitor) {}
        }
        impl AsyncRule for AsyncConst {
            type Output = i32;
            async fn evaluate(&mut self) -> i32 {
                self.0
            }
        }

        let mut builder = GraphBuilder::new_async();
        builder.set_async_output(AsyncConst(42));
        let mut graph = builder.build().unwrap();
        assert_eq!(*graph.evaluate_async().await, 42);
    }

    #[test]
    fn non_cloneable_output() {
        #[derive(PartialEq, Debug)]
        struct NonCloneable;
        struct Output;
        impl InputVisitable for Output {
            fn visit_inputs(&self, _visitor: &mut impl InputVisitor) {}
        }
        impl Rule for Output {
            type Output = NonCloneable;
            fn evaluate(&mut self) -> Self::Output {
                NonCloneable
            }
        }

        let mut builder = GraphBuilder::new();
        builder.set_output(Output);
        let mut graph = builder.build().unwrap();
        assert_eq!(*graph.evaluate(), NonCloneable);
    }

    #[test]
    fn only_update_downstream_nodes_if_value_changes() {
        let mut builder = GraphBuilder::new();
        let (a, invalidate) = builder.add_invalidatable_rule(ConstantRule::new(0));
        struct IncAdd(Input<i32>, i32);
        impl InputVisitable for IncAdd {
            fn visit_inputs(&self, visitor: &mut impl InputVisitor) {
                visitor.visit(&self.0);
            }
        }
        impl Rule for IncAdd {
            type Output = i32;
            fn evaluate(&mut self) -> Self::Output {
                self.1 += 1;
                *self.0.value() + self.1
            }
        }
        builder.set_output(IncAdd(a, 0));
        let mut graph = builder.build().unwrap();
        assert_eq!(*graph.evaluate(), 1);

        // IncAdd should not be evaluated again, despite its input being invalidated, so the output should be unchanged
        invalidate.invalidate();
        assert!(!graph.is_output_valid());
        assert_eq!(*graph.evaluate(), 1);
    }

    #[test]
    fn invalidatable_value() {
        let mut builder = GraphBuilder::new();
        let (a, invalidate) = builder.add_invalidatable_value(0);
        let b = builder.add_value(1);
        builder.set_output(Add(a, b));
        let mut graph = builder.build().unwrap();
        assert_eq!(*graph.evaluate(), 1);
        invalidate.set_value(42);
        assert!(!graph.is_output_valid());
        assert_eq!(*graph.evaluate(), 43);
    }

    #[tokio::test]
    async fn async_value() {
        let mut builder = GraphBuilder::new_async();
        let a = builder.add_async_value(|| async { 42 });
        let b = builder.add_value(1);
        builder.set_output(Add(a, b));
        let mut graph = builder.build().unwrap();
        assert_eq!(*graph.evaluate_async().await, 43);
    }
}
|