Document all the things
This commit is contained in:
parent
c1c594d4f7
commit
88dfef75fd
@ -8,6 +8,12 @@ use crate::{
|
||||
use std::cell::{Cell, RefCell};
|
||||
use std::rc::Rc;
|
||||
|
||||
/// Builds a [`Graph`].
|
||||
///
|
||||
/// The builder is generic over the type of the output value and the synchronicity of the graph.
|
||||
/// Use [`GraphBuilder::new`] or [`GraphBuilder::new_async`] to create a builder.
|
||||
/// Use [`GraphBuilder::set_output`] (or [`GraphBuilder::set_async_output`]) to set the rule for
|
||||
/// the output node.
|
||||
pub struct GraphBuilder<Output, Synch: Synchronicity> {
|
||||
pub(crate) node_graph: Rc<RefCell<NodeGraph<Synch>>>,
|
||||
pub(crate) output: Option<Input<Output>>,
|
||||
@ -16,6 +22,7 @@ pub struct GraphBuilder<Output, Synch: Synchronicity> {
|
||||
}
|
||||
|
||||
impl<O: 'static> GraphBuilder<O, Synchronous> {
|
||||
/// Creates a builder for a synchronous graph.
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
node_graph: Rc::new(RefCell::new(NodeGraph::new())),
|
||||
@ -27,6 +34,7 @@ impl<O: 'static> GraphBuilder<O, Synchronous> {
|
||||
}
|
||||
|
||||
impl<O: 'static> GraphBuilder<O, Asynchronous> {
|
||||
/// Creates a builder for an asynchronous graph.
|
||||
pub fn new_async() -> Self {
|
||||
Self {
|
||||
node_graph: Rc::new(RefCell::new(NodeGraph::new())),
|
||||
@ -38,6 +46,9 @@ impl<O: 'static> GraphBuilder<O, Asynchronous> {
|
||||
}
|
||||
|
||||
impl<O: 'static, S: Synchronicity> GraphBuilder<O, S> {
|
||||
/// Sets a synchronous rule for the output node.
|
||||
///
|
||||
/// The type of the output rule's value is the type of the output value of the overall graph.
|
||||
pub fn set_output<R: Rule<Output = O>>(&mut self, rule: R) {
|
||||
let input = self.add_rule(rule);
|
||||
self.output = Some(input);
|
||||
@ -53,10 +64,39 @@ impl<O: 'static, S: Synchronicity> GraphBuilder<O, S> {
|
||||
}
|
||||
}
|
||||
|
||||
/// Adds a constant node with the given value to the graph.
|
||||
///
|
||||
/// Returns an [`Input`] representing the newly-added node, which can be used to construct rules.
|
||||
pub fn add_value<V: NodeValue>(&mut self, value: V) -> Input<V> {
|
||||
return self.add_node(ConstNode::new(value));
|
||||
}
|
||||
|
||||
/// Adds an invalidatable node with the given value to the graph.
|
||||
///
|
||||
/// Returns an [`Input`] and a [`ValueInvalidationSignal`] representing, respectively, the
|
||||
/// newly-added node, which can be used be used to construct rules, and a signal through which
|
||||
/// the value of the node can be replaced, invalidating the node in the process.
|
||||
///
|
||||
/// ```rust
|
||||
/// # use compute_graph::{builder::GraphBuilder, rule::{Rule, Input, InputVisitor}};
|
||||
/// let mut builder = GraphBuilder::new();
|
||||
/// let (input, signal) = builder.add_invalidatable_value(0);
|
||||
/// # struct Double(Input<i32>);
|
||||
/// # impl Rule for Double {
|
||||
/// # type Output = i32;
|
||||
/// # fn visit_inputs(&self, visitor: &mut impl InputVisitor) {
|
||||
/// # visitor.visit(&self.0);
|
||||
/// # }
|
||||
/// # fn evaluate(&mut self) -> i32 {
|
||||
/// # *self.0.value() * 2
|
||||
/// # }
|
||||
/// # }
|
||||
/// builder.set_output(Double(input));
|
||||
/// let mut graph = builder.build().unwrap();
|
||||
/// assert_eq!(*graph.evaluate(), 0);
|
||||
/// signal.set_value(1);
|
||||
/// assert_eq!(*graph.evaluate(), 2);
|
||||
/// ```
|
||||
pub fn add_invalidatable_value<V: NodeValue>(
|
||||
&mut self,
|
||||
value: V,
|
||||
@ -71,6 +111,9 @@ impl<O: 'static, S: Synchronicity> GraphBuilder<O, S> {
|
||||
(input, signal)
|
||||
}
|
||||
|
||||
/// Adds a node whose value is produced using the given rule to the graph.
|
||||
///
|
||||
/// Returns an [`Input`] representing the newly-added node, which can be used to construct further rules.
|
||||
pub fn add_rule<R>(&mut self, rule: R) -> Input<R::Output>
|
||||
where
|
||||
R: Rule,
|
||||
@ -78,22 +121,61 @@ impl<O: 'static, S: Synchronicity> GraphBuilder<O, S> {
|
||||
return self.add_node(RuleNode::new(rule));
|
||||
}
|
||||
|
||||
pub fn add_invalidatable_rule<R, F>(&mut self, mut f: F) -> Input<R::Output>
|
||||
where
|
||||
R: Rule,
|
||||
F: FnMut(InvalidationSignal<S>) -> R,
|
||||
{
|
||||
let node_idx = Rc::new(Cell::new(None));
|
||||
/// Adds an externally-invalidatable node whose value is produced using the given rule to the graph.
|
||||
///
|
||||
/// Returns an [`Input`] representing the newly-added node, which can be used to construct further rules,
|
||||
/// as well as an [`InvalidationSignal`] which can be used to indicate that the node has been invalidated.
|
||||
///
|
||||
/// ```rust
|
||||
/// # use compute_graph::{builder::GraphBuilder, rule::{Rule, Input, InputVisitor}};
|
||||
/// let mut builder = GraphBuilder::new();
|
||||
/// # struct IncrementAfterEvaluate(i32);
|
||||
/// # impl Rule for IncrementAfterEvaluate {
|
||||
/// # type Output = i32;
|
||||
/// # fn visit_inputs(&self, visitor: &mut impl InputVisitor) {}
|
||||
/// # fn evaluate(&mut self) -> i32 {
|
||||
/// # let result = self.0;
|
||||
/// # self.0 += 1;
|
||||
/// # result
|
||||
/// # }
|
||||
/// # }
|
||||
/// # struct Double(Input<i32>);
|
||||
/// # impl Rule for Double {
|
||||
/// # type Output = i32;
|
||||
/// # fn visit_inputs(&self, visitor: &mut impl InputVisitor) {
|
||||
/// # visitor.visit(&self.0);
|
||||
/// # }
|
||||
/// # fn evaluate(&mut self) -> i32 {
|
||||
/// # *self.0.value() * 2
|
||||
/// # }
|
||||
/// # }
|
||||
/// let (input, signal) = builder.add_invalidatable_rule(IncrementAfterEvaluate(1));
|
||||
/// builder.set_output(Double(input));
|
||||
/// let mut graph = builder.build().unwrap();
|
||||
/// assert_eq!(*graph.evaluate(), 2);
|
||||
/// signal.invalidate();
|
||||
/// assert_eq!(*graph.evaluate(), 4);
|
||||
/// ```
|
||||
pub fn add_invalidatable_rule<R: Rule>(
|
||||
&mut self,
|
||||
rule: R,
|
||||
) -> (Input<R::Output>, InvalidationSignal<S>) {
|
||||
let input = self.add_rule(rule);
|
||||
let signal = InvalidationSignal {
|
||||
node_idx: Rc::clone(&node_idx),
|
||||
node_idx: Rc::new(Cell::new(Some(input.node_idx))),
|
||||
graph: Rc::clone(&self.node_graph),
|
||||
graph_is_valid: Rc::clone(&self.is_valid),
|
||||
};
|
||||
let input = self.add_rule(f(signal));
|
||||
node_idx.set(Some(input.node_idx));
|
||||
input
|
||||
(input, signal)
|
||||
}
|
||||
|
||||
/// Creates a graph from this builder, consuming the builder.
|
||||
///
|
||||
/// To successfully build a graph, there must be an output node set (using either
|
||||
/// [`GraphBuilder::set_output`] or [`GraphBuilder::set_async_output`]) and there canont be any
|
||||
/// cycles in the graph.
|
||||
///
|
||||
/// Any nodes present in the builder not connected to the output node are removed from the graph.
|
||||
pub fn build(self) -> Result<Graph<O, S>, BuildGraphError> {
|
||||
let output: &Input<O> = match &self.output {
|
||||
None => return Err(BuildGraphError::NoOutput),
|
||||
@ -122,7 +204,7 @@ impl<O: 'static, S: Synchronicity> GraphBuilder<O, S> {
|
||||
let sorted_nodes =
|
||||
petgraph::algo::toposort(&**self.node_graph.borrow(), None).map_err(|_| {
|
||||
// TODO: actually build a vec describing the cycle path for debugging
|
||||
BuildGraphError::Cyclic(vec![])
|
||||
BuildGraphError::Cycle(vec![])
|
||||
})?;
|
||||
|
||||
let graph = Graph {
|
||||
@ -138,11 +220,19 @@ impl<O: 'static, S: Synchronicity> GraphBuilder<O, S> {
|
||||
}
|
||||
|
||||
impl<O: 'static> GraphBuilder<O, Asynchronous> {
|
||||
/// Sets an asynchronous rule for the output node.
|
||||
///
|
||||
/// The type of the output rule's value is the type of the output value of the overall graph.
|
||||
pub fn set_async_output<R: AsyncRule<Output = O>>(&mut self, rule: R) {
|
||||
let input = self.add_async_rule(rule);
|
||||
self.output = Some(input);
|
||||
}
|
||||
|
||||
// TODO: add_async_value?
|
||||
|
||||
/// Adds a node whose value is produced using the given rule to the graph.
|
||||
///
|
||||
/// Returns an [`Input`] representing the newly-added node, which can be used to construct further rules.
|
||||
pub fn add_async_rule<R>(&mut self, rule: R) -> Input<R::Output>
|
||||
where
|
||||
R: AsyncRule,
|
||||
@ -150,6 +240,10 @@ impl<O: 'static> GraphBuilder<O, Asynchronous> {
|
||||
self.add_node(AsyncRuleNode::new(rule))
|
||||
}
|
||||
|
||||
/// Adds an externally-invalidatable node whose value is produced using the given async rule to the graph.
|
||||
///
|
||||
/// Returns an [`Input`] representing the newly-added node, which can be used to construct further rules,
|
||||
/// as well as an [`InvalidationSignal`] which can be used to indicate that the node has been invalidated.
|
||||
pub fn add_invalidatable_async_rule<R, F>(&mut self, mut f: F) -> Input<R::Output>
|
||||
where
|
||||
R: AsyncRule,
|
||||
@ -167,8 +261,14 @@ impl<O: 'static> GraphBuilder<O, Asynchronous> {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
/// A reason why a [`GraphBuilder`] can fail to build a graph.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum BuildGraphError {
|
||||
/// No output rule has been specified with [`GraphBuilder::set_output`].
|
||||
NoOutput,
|
||||
Cyclic(Vec<NodeId>),
|
||||
/// There is a cycle in the graph between the given nodes.
|
||||
///
|
||||
/// The first and last element of the `Vec` are the same, with the elements in between representing
|
||||
/// the path on which the cycle was found.
|
||||
Cycle(Vec<NodeId>),
|
||||
}
|
||||
|
@ -1,7 +1,57 @@
|
||||
mod builder;
|
||||
mod node;
|
||||
mod rule;
|
||||
mod synchronicity;
|
||||
//! Facilities for using a directed, acyclic graph to perform computation.
|
||||
//!
|
||||
//! A directed, acyclic graph (DAG) can be used to carry out computations by considering
|
||||
//! each node to have a value and each edge to represent a dependency on the value of one
|
||||
//! node to compute the value of another node. A node's value can either be constant or be
|
||||
//! produced by a rule, which is a piece of code for generating the value of a node given its
|
||||
//! dependencies. For example, an arithmetic operation can be implemented like so:
|
||||
//!
|
||||
//! ```rust
|
||||
//! # use compute_graph::{builder::GraphBuilder, rule::{Rule, Input, InputVisitor}};
|
||||
//! let mut builder = GraphBuilder::new();
|
||||
//! let a = builder.add_value(1);
|
||||
//! let b = builder.add_value(2);
|
||||
//! # struct Add(Input<i32>, Input<i32>);
|
||||
//! # impl Rule for Add {
|
||||
//! # type Output = i32;
|
||||
//! # fn visit_inputs(&self, visitor: &mut impl InputVisitor) {
|
||||
//! # visitor.visit(&self.0);
|
||||
//! # visitor.visit(&self.1);
|
||||
//! # }
|
||||
//! # fn evaluate(&mut self) -> i32 {
|
||||
//! # *self.0.value() + *self.1.value()
|
||||
//! # }
|
||||
//! # }
|
||||
//! builder.set_output(Add(a, b));
|
||||
//!
|
||||
//! let mut graph = builder.build().unwrap();
|
||||
//! assert_eq!(*graph.evaluate(), 3);
|
||||
//! ```
|
||||
//!
|
||||
//! Here, `a` and `b` are placeholders representing the values of the two constant nodes in the graph.
|
||||
//! The `Add` struct implements the [`Rule`] trait and defines how to combine those two values by addition.
|
||||
//! The `Add` rule is implemented as follows:
|
||||
//!
|
||||
//! ```rust
|
||||
//! # use compute_graph::{builder::GraphBuilder, rule::{Rule, Input, InputVisitor}};
|
||||
//! struct Add(Input<i32>, Input<i32>);
|
||||
//!
|
||||
//! impl Rule for Add {
|
||||
//! type Output = i32;
|
||||
//! fn visit_inputs(&self, visitor: &mut impl InputVisitor) {
|
||||
//! visitor.visit(&self.0);
|
||||
//! visitor.visit(&self.1);
|
||||
//! }
|
||||
//! fn evaluate(&mut self) -> i32 {
|
||||
//! *self.0.value() + *self.1.value()
|
||||
//! }
|
||||
//! }
|
||||
//! ```
|
||||
|
||||
pub mod builder;
|
||||
pub mod node;
|
||||
pub mod rule;
|
||||
pub mod synchronicity;
|
||||
mod util;
|
||||
|
||||
use builder::{BuildGraphError, GraphBuilder};
|
||||
@ -9,7 +59,7 @@ use node::{ErasedNode, NodeValue};
|
||||
use petgraph::visit::{IntoEdgeReferences, NodeIndexable};
|
||||
use petgraph::{stable_graph::StableDiGraph, visit::EdgeRef};
|
||||
use rule::{AsyncRule, Input, InputVisitor, Rule};
|
||||
use std::cell::{Cell, RefCell};
|
||||
use std::cell::{Cell, Ref, RefCell};
|
||||
use std::collections::HashMap;
|
||||
use std::collections::VecDeque;
|
||||
use std::ops::{Deref, DerefMut};
|
||||
@ -40,6 +90,12 @@ impl<S: Synchronicity> DerefMut for NodeGraph<S> {
|
||||
}
|
||||
}
|
||||
|
||||
/// A constructed graph that can evaluated.
|
||||
///
|
||||
/// Use [`GraphBuilder`] to construct a graph.
|
||||
///
|
||||
/// The graph is generic over the type of the output node's value and the [`Synchronicity`]
|
||||
/// —that is, whether it can be evaluated synchronously or asynchronously.
|
||||
pub struct Graph<Output, Synch: Synchronicity> {
|
||||
node_graph: Rc<RefCell<NodeGraph<Synch>>>,
|
||||
output: Input<Output>,
|
||||
@ -49,17 +105,34 @@ pub struct Graph<Output, Synch: Synchronicity> {
|
||||
is_valid: Rc<Cell<bool>>,
|
||||
}
|
||||
|
||||
/// A synchronous graph, containing only sync nodes.
|
||||
pub type SyncGraph<Output> = Graph<Output, Synchronous>;
|
||||
|
||||
/// An asynchronous graph, containing a mix of sync and async nodes.
|
||||
pub type AsyncGraph<Output> = Graph<Output, Asynchronous>;
|
||||
|
||||
impl<O: 'static, S: Synchronicity> Graph<O, S> {
|
||||
/// Whether the output value of the graph is currently valid.
|
||||
///
|
||||
/// The output is considered presumptively invalid if _any_ of the nodes in the graph are invalid,
|
||||
/// even if, when evaluated, the invalid node's value is unchanged (in which case, downstream nodes
|
||||
/// are not invalidated) and the output may be unchanged.
|
||||
pub fn is_output_valid(&self) -> bool {
|
||||
let graph = self.node_graph.borrow();
|
||||
let node = &graph[self.output.node_idx];
|
||||
self.is_valid.get() && node.is_valid()
|
||||
self.is_valid.get()
|
||||
}
|
||||
|
||||
/// The number of nodes in the graph.
|
||||
pub fn node_count(&self) -> usize {
|
||||
self.node_graph.borrow().node_count()
|
||||
}
|
||||
|
||||
/// Modify the graph using the given function.
|
||||
///
|
||||
/// The function receives as its parameter a [`GraphBuilder`] representing the current graph.
|
||||
///
|
||||
/// Because building a graph can fail and this method mutates the underlying graph, it takes
|
||||
/// ownership of the current graph to prevent the graph being left in an invalid state.
|
||||
/// It returns either the new, modified graph or an error.
|
||||
pub fn modify<F>(mut self, mut f: F) -> Result<Self, BuildGraphError>
|
||||
where
|
||||
F: FnMut(&mut GraphBuilder<O, S>) -> (),
|
||||
@ -129,11 +202,13 @@ impl<O: 'static> Graph<O, Synchronous> {
|
||||
let value_changed = node.update();
|
||||
|
||||
if value_changed {
|
||||
// Invalidate any downstream nodes (which we know we haven't visited yet, because we're iterating over a topological sort of the graph)
|
||||
// Invalidate any downstream nodes (which we know we haven't visited yet, because
|
||||
// we're iterating over a topological sort of the graph).
|
||||
let dependents = graph
|
||||
.edges_directed(idx, petgraph::Direction::Outgoing)
|
||||
.map(|edge| edge.target())
|
||||
// Need to collect because the edges_directed iterator borrows the graph, and we need to mutably borrow to invalidate
|
||||
// Need to collect because the edges_directed iterator borrows the graph, and
|
||||
// we need to mutably borrow to invalidate.
|
||||
.collect::<Vec<_>>();
|
||||
for dependent_idx in dependents {
|
||||
let dependent = &mut graph[dependent_idx];
|
||||
@ -142,7 +217,8 @@ impl<O: 'static> Graph<O, Synchronous> {
|
||||
}
|
||||
}
|
||||
}
|
||||
// Consistency check: after updating in the topological sort order, we should be left with no invalid nodes
|
||||
// Consistency check: after updating in the topological sort order, we should be left with
|
||||
// no invalid nodes
|
||||
debug_assert!(self
|
||||
.sorted_nodes
|
||||
.iter()
|
||||
@ -150,6 +226,13 @@ impl<O: 'static> Graph<O, Synchronous> {
|
||||
self.is_valid.set(true);
|
||||
}
|
||||
|
||||
/// Synchronously evaluate the graph and return a reference to the value of the output node.
|
||||
///
|
||||
/// If the graph is valid (see [`Graph::is_output_valid`]), this is a constant-time operation.
|
||||
/// Otherwise, any invalid nodes and their downstream dependents will be updated, which is an
|
||||
/// O(n) operation.
|
||||
///
|
||||
/// This method is only available on synchronous graphs, which can only contain synchronous nodes.
|
||||
pub fn evaluate(&mut self) -> impl Deref<Target = O> + '_ {
|
||||
if !self.is_valid.get() {
|
||||
self.update_invalid_nodes();
|
||||
@ -169,11 +252,13 @@ impl<O: 'static> Graph<O, Asynchronous> {
|
||||
let value_changed = node.update().await;
|
||||
|
||||
if value_changed {
|
||||
// Invalidate any downstream nodes (which we know we haven't visited yet, because we're iterating over a topological sort of the graph)
|
||||
// Invalidate any downstream nodes (which we know we haven't visited yet, because
|
||||
// we're iterating over a topological sort of the graph).
|
||||
let dependents = graph
|
||||
.edges_directed(idx, petgraph::Direction::Outgoing)
|
||||
.map(|edge| edge.target())
|
||||
// Need to collect because the edges_directed iterator borrows the graph, and we need to mutably borrow to invalidate
|
||||
// Need to collect because the edges_directed iterator borrows the graph, and
|
||||
// we need to mutably borrow to invalidate.
|
||||
.collect::<Vec<_>>();
|
||||
for dependent_idx in dependents {
|
||||
let dependent = &mut graph[dependent_idx];
|
||||
@ -182,7 +267,8 @@ impl<O: 'static> Graph<O, Asynchronous> {
|
||||
}
|
||||
}
|
||||
}
|
||||
// Consistency check: after updating in the topological sort order, we should be left with no invalid nodes
|
||||
// Consistency check: after updating in the topological sort order, we should be left with
|
||||
// no invalid nodes
|
||||
debug_assert!(self
|
||||
.sorted_nodes
|
||||
.iter()
|
||||
@ -190,6 +276,14 @@ impl<O: 'static> Graph<O, Asynchronous> {
|
||||
self.is_valid.set(true);
|
||||
}
|
||||
|
||||
/// Asynchronously evaluate the graph and return a reference to the value of the output node.
|
||||
///
|
||||
/// If the graph is valid (see [`Graph::is_output_valid`]), this is a constant-time operation.
|
||||
/// Otherwise, any invalid nodes and their downstream dependents will be updated, which is an
|
||||
/// O(n) operation.
|
||||
///
|
||||
/// This method is only available on asynchronous graphs, which can contain a mix of asynchronous
|
||||
/// and synchronous nodes.
|
||||
pub async fn evaluate_async(&mut self) -> impl Deref<Target = O> + '_ {
|
||||
if !self.is_valid.get() {
|
||||
self.update_invalid_nodes().await;
|
||||
@ -198,7 +292,13 @@ impl<O: 'static> Graph<O, Asynchronous> {
|
||||
}
|
||||
}
|
||||
|
||||
/// A type representing a node in a graph that can be invalidated due to external factors.
|
||||
///
|
||||
/// See [`GraphBuilder::add_invalidatable_rule`].
|
||||
///
|
||||
/// `InvalidationSignal` implements `Clone`, so the signal can be cloned and used from multiple places.
|
||||
// TODO: there's a lot happening here, make sure this doesn't create a reference cycle
|
||||
// TODO: would be better if this didn't have to be generic over Synchronicity
|
||||
pub struct InvalidationSignal<Synch: Synchronicity> {
|
||||
node_idx: Rc<Cell<Option<NodeId>>>,
|
||||
graph: Rc<RefCell<NodeGraph<Synch>>>,
|
||||
@ -206,6 +306,11 @@ pub struct InvalidationSignal<Synch: Synchronicity> {
|
||||
}
|
||||
|
||||
impl<S: Synchronicity> InvalidationSignal<S> {
|
||||
/// Tell the graph that the node corresponding to this signal is now invalid.
|
||||
///
|
||||
/// Note: Calling this method does not trigger a graph evaluation, it merely marks the corresponding
|
||||
/// node as invalid. The graph will not be re-evaluated until [`Graph::evaluate`] or
|
||||
/// [`Graph::evaluate_async`] is next called.
|
||||
pub fn invalidate(&self) {
|
||||
self.graph_is_valid.set(false);
|
||||
let mut graph = self.graph.borrow_mut();
|
||||
@ -214,6 +319,19 @@ impl<S: Synchronicity> InvalidationSignal<S> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<S: Synchronicity> Clone for InvalidationSignal<S> {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
node_idx: Rc::clone(&self.node_idx),
|
||||
graph: Rc::clone(&self.graph),
|
||||
graph_is_valid: Rc::clone(&self.graph_is_valid),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A type representing a node with an externally injected value.
|
||||
///
|
||||
/// See [`GraphBuilder::add_invalidatable_value`].
|
||||
pub struct ValueInvalidationSignal<V, Synch: Synchronicity> {
|
||||
node_idx: NodeId,
|
||||
value: Rc<RefCell<Option<V>>>,
|
||||
@ -222,6 +340,19 @@ pub struct ValueInvalidationSignal<V, Synch: Synchronicity> {
|
||||
}
|
||||
|
||||
impl<V: NodeValue, S: Synchronicity> ValueInvalidationSignal<V, S> {
|
||||
/// Get a reference to current value for the node corresponding to this signal.
|
||||
pub fn value(&self) -> impl Deref<Target = V> + '_ {
|
||||
Ref::map(self.value.borrow(), |opt| {
|
||||
opt.as_ref()
|
||||
.expect("invalidatable value node must be initialized with value")
|
||||
})
|
||||
}
|
||||
|
||||
/// Set a new value for the node corresponding to this signal.
|
||||
///
|
||||
/// Note: Calling this method does not trigger a graph evaluation, it merely sets a new value
|
||||
/// for the corresponding node. The graph will not be re-evaluated until [`Graph::evaluate`] or
|
||||
/// [`Graph::evaluate_async`] is next called.
|
||||
pub fn set_value(&self, value: V) {
|
||||
let mut current_value = self.value.borrow_mut();
|
||||
if !current_value
|
||||
@ -238,8 +369,6 @@ impl<V: NodeValue, S: Synchronicity> ValueInvalidationSignal<V, S> {
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: i really want Input to be able to implement Deref somehow
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
@ -303,18 +432,14 @@ mod tests {
|
||||
#[test]
|
||||
fn invalidatable_rule() {
|
||||
let mut builder = GraphBuilder::new();
|
||||
let mut invalidate = None;
|
||||
let input = builder.add_invalidatable_rule(|inv| {
|
||||
invalidate = Some(inv);
|
||||
Inc(0)
|
||||
});
|
||||
let (input, invalidate) = builder.add_invalidatable_rule(Inc(0));
|
||||
builder.set_output(Double(input));
|
||||
let mut graph = builder.build().unwrap();
|
||||
assert_eq!(*graph.evaluate(), 2);
|
||||
invalidate.as_ref().unwrap().invalidate();
|
||||
invalidate.invalidate();
|
||||
assert_eq!(*graph.evaluate(), 4);
|
||||
assert_eq!(*graph.evaluate(), 4);
|
||||
invalidate.as_ref().unwrap().invalidate();
|
||||
invalidate.invalidate();
|
||||
assert_eq!(*graph.evaluate(), 6);
|
||||
}
|
||||
|
||||
@ -342,16 +467,12 @@ mod tests {
|
||||
#[test]
|
||||
fn rule_with_invalidatable_inputs() {
|
||||
let mut builder = GraphBuilder::new();
|
||||
let mut invalidate = None;
|
||||
let a = builder.add_invalidatable_rule(|inv| {
|
||||
invalidate = Some(inv);
|
||||
Inc(0)
|
||||
});
|
||||
let (a, invalidate) = builder.add_invalidatable_rule(Inc(0));
|
||||
let b = builder.add_rule(Inc(0));
|
||||
builder.set_output(Add(a, b));
|
||||
let mut graph = builder.build().unwrap();
|
||||
assert_eq!(*graph.evaluate(), 2);
|
||||
invalidate.as_ref().unwrap().invalidate();
|
||||
invalidate.invalidate();
|
||||
assert_eq!(*graph.evaluate(), 3);
|
||||
assert_eq!(*graph.evaluate(), 3);
|
||||
}
|
||||
@ -389,7 +510,7 @@ mod tests {
|
||||
*a_input.borrow_mut() = Some(b.clone());
|
||||
builder.set_output(Double(b));
|
||||
match builder.build() {
|
||||
Err(BuildGraphError::Cyclic(_)) => (),
|
||||
Err(BuildGraphError::Cycle(_)) => (),
|
||||
Err(e) => assert!(false, "unexpected error {:?}", e),
|
||||
Ok(_) => assert!(false, "shouldn't have frozen graph"),
|
||||
}
|
||||
@ -472,11 +593,7 @@ mod tests {
|
||||
#[test]
|
||||
fn only_update_downstream_nodes_if_value_changes() {
|
||||
let mut builder = GraphBuilder::new();
|
||||
let mut invalidate = None;
|
||||
let a = builder.add_invalidatable_rule(|inv| {
|
||||
invalidate = Some(inv);
|
||||
ConstantRule::new(0)
|
||||
});
|
||||
let (a, invalidate) = builder.add_invalidatable_rule(ConstantRule::new(0));
|
||||
struct IncAdd(Input<i32>, i32);
|
||||
impl Rule for IncAdd {
|
||||
type Output = i32;
|
||||
@ -493,7 +610,7 @@ mod tests {
|
||||
assert_eq!(*graph.evaluate(), 1);
|
||||
|
||||
// IncAdd should not be evaluated again, despite its input being invalidated, so the output should be unchanged
|
||||
invalidate.unwrap().invalidate();
|
||||
invalidate.invalidate();
|
||||
assert!(!graph.is_output_valid());
|
||||
assert_eq!(*graph.evaluate(), 1);
|
||||
}
|
||||
|
@ -66,8 +66,32 @@ pub(crate) trait Node<Value: NodeValue, Synch: Synchronicity> {
|
||||
fn value_rc(&self) -> &Rc<RefCell<Option<Value>>>;
|
||||
}
|
||||
|
||||
/// A value that can be used as the value of a node in the graph.
|
||||
///
|
||||
/// This trait is used to determine, when a node is invalidated, whether its value has truly changed
|
||||
/// and thus whether downstream nodes need to be invalidated too.
|
||||
///
|
||||
/// A blanket implementation of this trait for all types implementing `PartialEq` is provided.
|
||||
pub trait NodeValue: 'static {
|
||||
fn node_value_eq(&self, other: &Self) -> bool;
|
||||
/// Whether self is equal, for the purposes of graph invalidation, from other.
|
||||
///
|
||||
/// This method should be conservative. That is, if the equality of the two values cannot be affirmatively
|
||||
/// determined, this method should return `false`.
|
||||
///
|
||||
/// The default implementation of this method always returns `false`, so any non-`PartialEq` type can
|
||||
/// implement this trait simply:
|
||||
///
|
||||
/// ```rust
|
||||
/// # use compute_graph::node::NodeValue;
|
||||
/// struct MyType;
|
||||
/// impl NodeValue for MyType {}
|
||||
/// ```
|
||||
///
|
||||
/// Note that always returning `false` may result in more node invalidations than strictly necessary.
|
||||
#[allow(unused_variables)]
|
||||
fn node_value_eq(&self, other: &Self) -> bool {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: PartialEq + 'static> NodeValue for T {
|
||||
|
@ -4,22 +4,95 @@ use std::cell::{Ref, RefCell};
|
||||
use std::ops::Deref;
|
||||
use std::rc::Rc;
|
||||
|
||||
/// A rule produces a value for a graph node using its [`Input`]s.
|
||||
///
|
||||
/// A rule for addition could be implemented like so:
|
||||
///
|
||||
/// ```rust
|
||||
/// # use compute_graph::rule::{Rule, Input, InputVisitor};
|
||||
/// struct Add(Input<i32>, Input<i32>);
|
||||
///
|
||||
/// impl Rule for Add {
|
||||
/// type Output = i32;
|
||||
///
|
||||
/// fn visit_inputs(&self, visitor: &mut impl InputVisitor) {
|
||||
/// visitor.visit(&self.0);
|
||||
/// visitor.visit(&self.1);
|
||||
/// }
|
||||
///
|
||||
/// fn evaluate(&mut self) -> Self::Output {
|
||||
/// *self.0.value() + *self.1.value()
|
||||
/// }
|
||||
/// }
|
||||
/// ```
|
||||
pub trait Rule: 'static {
|
||||
/// The type of the output value of the rule.
|
||||
type Output: NodeValue;
|
||||
|
||||
/// Visits all the [`Input`]s of this rule.
|
||||
///
|
||||
/// This method is called when the graph is built/modified in order to establish edges of the graph,
|
||||
/// representing the dependencies. Any input that the [`InputVisitor::visit`] is called with is
|
||||
/// considered a dependency of the rule's node.
|
||||
///
|
||||
/// While it is permitted for the dependencies of a rule to change after it has been added to the graph,
|
||||
/// doing so only permitted before the graph has been built or during the callback of
|
||||
/// [`Graph::modify`](`crate::Graph::modify`). Changes to the rule's dependencies outside of that will
|
||||
/// not be detected and will not be represented in the graph.
|
||||
fn visit_inputs(&self, visitor: &mut impl InputVisitor);
|
||||
|
||||
/// Produces the value of this rule using its inputs.
|
||||
///
|
||||
/// Note that the receiver of this method is a mutable reference to the rule itself. Rules are permitted
|
||||
/// to have internal state that they modify during evaluation.
|
||||
///
|
||||
/// The following guarantees are made about rule evaluation:
|
||||
/// 1. A rule will only be evaluated when one or more of its dependencies has changed. Note that "changed"
|
||||
/// referes to returning `false` from [`NodeValue::node_value_eq`] for the dependency.
|
||||
/// 2. A rule will never be evaluated before _all_ of its dependencies up-to-date. That is, it will never
|
||||
/// be evaluated with mix of valid and invalid dependencies.
|
||||
fn evaluate(&mut self) -> Self::Output;
|
||||
}
|
||||
|
||||
/// A rule produces a value for a graph node asynchronously.
|
||||
///
|
||||
/// ```rust
|
||||
/// # use compute_graph::rule::{AsyncRule, Input, InputVisitor};
|
||||
/// # async fn do_async_work(_: i32) -> i32 { 0 }
|
||||
/// struct AsyncMath(Input<i32>);
|
||||
///
|
||||
/// impl AsyncRule for AsyncMath {
|
||||
/// type Output = i32;
|
||||
///
|
||||
/// fn visit_inputs(&self, visitor: &mut impl InputVisitor) {
|
||||
/// visitor.visit(&self.0);
|
||||
/// }
|
||||
///
|
||||
/// async fn evaluate(&mut self) -> Self::Output {
|
||||
/// do_async_work(*self.0.value()).await
|
||||
/// }
|
||||
/// }
|
||||
/// ```
|
||||
pub trait AsyncRule: 'static {
|
||||
/// The type of the output value of the rule.
|
||||
type Output: NodeValue;
|
||||
|
||||
/// Visits all the [`Input`]s of this rule.
|
||||
///
|
||||
/// See [`Rule::visit_inputs`] for additional details; the same caveats apply.
|
||||
fn visit_inputs(&self, visitor: &mut impl InputVisitor);
|
||||
|
||||
/// Asynchronously produces the value of this rule using its inputs.
|
||||
///
|
||||
/// See [`Rule::evaluate`] for additional details; the same considerations apply.
|
||||
async fn evaluate(&mut self) -> Self::Output;
|
||||
}
|
||||
|
||||
/// A placeholder for the output of one node to be used as an input for another.
|
||||
///
|
||||
/// To obtain an input, add a value or rule to a [`GraphBuilder`](`crate::builder::GraphBuilder`).
|
||||
///
|
||||
/// Note that this type implements `Clone`, so can be cloned and used as an input for multiple nodes.
|
||||
#[derive(Debug)]
|
||||
pub struct Input<T> {
|
||||
pub(crate) node_idx: NodeId,
|
||||
@ -27,6 +100,9 @@ pub struct Input<T> {
|
||||
}
|
||||
|
||||
impl<T> Input<T> {
|
||||
/// Retrieves a reference to the current value of the node the input represents.
|
||||
///
|
||||
/// Calling this method before the node it represents has been evaluated will panic.
|
||||
pub fn value(&self) -> impl Deref<Target = T> + '_ {
|
||||
Ref::map(self.value.borrow(), |opt| {
|
||||
opt.as_ref()
|
||||
@ -47,13 +123,21 @@ impl<T> Clone for Input<T> {
|
||||
|
||||
// TODO: i really want Input to be able to implement Deref somehow
|
||||
|
||||
/// A type that can visit arbitrary [`Input`]s.
|
||||
///
|
||||
/// You generally do not implement this trait yourself. An implementation is provided to [`Rule::visit_inputs`].
|
||||
pub trait InputVisitor {
|
||||
/// Visit an input whose value is of type `T`.
|
||||
fn visit<T>(&mut self, input: &Input<T>);
|
||||
}
|
||||
|
||||
/// A simple rule that provides a constant value.
|
||||
///
|
||||
/// Note that, because [`Rule::evaluate`] returns an owned value, this rule's value type must implement `Clone`.
|
||||
pub struct ConstantRule<T>(T);
|
||||
|
||||
impl<T> ConstantRule<T> {
|
||||
/// Constructs a new constant rule with the given value.
|
||||
pub fn new(value: T) -> Self {
|
||||
Self(value)
|
||||
}
|
||||
|
@ -1,9 +1,21 @@
|
||||
//! Types used to make [`Graph`](`crate::Graph`) and [`GraphBuilder`](`crate::builder::GraphBuilder`) generic
|
||||
//! over the synchronicity of the graph.
|
||||
//!
|
||||
//! The [`Synchronicity`] trait is sealed to outside implementors, and you generally do not need to refer
|
||||
//! directly to the [`Synchronous`] or [`Asynchronous`] types.
|
||||
|
||||
use crate::node::{Node, NodeValue};
|
||||
use std::any::Any;
|
||||
use std::future::Future;
|
||||
use std::pin::Pin;
|
||||
|
||||
pub trait Synchronicity: 'static {
|
||||
mod private {
|
||||
pub trait Sealed {}
|
||||
impl Sealed for super::Synchronous {}
|
||||
impl Sealed for super::Asynchronous {}
|
||||
}
|
||||
|
||||
pub trait Synchronicity: private::Sealed + 'static {
|
||||
type UpdateFn;
|
||||
fn make_update_fn<V: NodeValue>() -> Self::UpdateFn;
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user