Assorted graph tweaks
This commit is contained in:
parent
712b528ca8
commit
e69014d98d
@ -333,6 +333,14 @@ pub enum BuildGraphError {
|
|||||||
Cycle(Vec<NodeId>),
|
Cycle(Vec<NodeId>),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Display for BuildGraphError {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
write!(f, "{self:?}")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::error::Error for BuildGraphError {}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use crate::tests::{DeferredInput, Double};
|
use crate::tests::{DeferredInput, Double};
|
||||||
|
@ -618,6 +618,11 @@ mod tests {
|
|||||||
fn non_cloneable_output() {
|
fn non_cloneable_output() {
|
||||||
#[derive(PartialEq, Debug)]
|
#[derive(PartialEq, Debug)]
|
||||||
struct NonCloneable;
|
struct NonCloneable;
|
||||||
|
impl NodeValue for NonCloneable {
|
||||||
|
fn node_value_eq(&self, _other: &Self) -> bool {
|
||||||
|
true
|
||||||
|
}
|
||||||
|
}
|
||||||
struct Output;
|
struct Output;
|
||||||
impl InputVisitable for Output {
|
impl InputVisitable for Output {
|
||||||
fn visit_inputs(&self, _visitor: &mut impl InputVisitor) {}
|
fn visit_inputs(&self, _visitor: &mut impl InputVisitor) {}
|
||||||
@ -689,7 +694,23 @@ mod tests {
|
|||||||
let mut builder = GraphBuilder::new();
|
let mut builder = GraphBuilder::new();
|
||||||
let a = builder.add_value(1);
|
let a = builder.add_value(1);
|
||||||
let b = builder.add_value(2);
|
let b = builder.add_value(2);
|
||||||
builder.set_output(Add(a, b));
|
struct AddWithLabel(Input<i32>, Input<i32>);
|
||||||
|
impl InputVisitable for AddWithLabel {
|
||||||
|
fn visit_inputs(&self, visitor: &mut impl InputVisitor) {
|
||||||
|
visitor.visit(&self.0);
|
||||||
|
visitor.visit(&self.1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
impl Rule for AddWithLabel {
|
||||||
|
type Output = i32;
|
||||||
|
fn evaluate(&mut self) -> Self::Output {
|
||||||
|
*self.0.value() + *self.1.value()
|
||||||
|
}
|
||||||
|
fn node_label(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||||
|
write!(f, "test")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
builder.set_output(AddWithLabel(a, b));
|
||||||
let graph = builder.build().unwrap();
|
let graph = builder.build().unwrap();
|
||||||
println!("{}", graph.as_dot_string());
|
println!("{}", graph.as_dot_string());
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
@ -697,7 +718,7 @@ mod tests {
|
|||||||
r#"digraph {
|
r#"digraph {
|
||||||
0 [label ="ConstNode<i32> (id=0)"]
|
0 [label ="ConstNode<i32> (id=0)"]
|
||||||
1 [label ="ConstNode<i32> (id=1)"]
|
1 [label ="ConstNode<i32> (id=1)"]
|
||||||
2 [label ="RuleNode<compute_graph::tests::Add> (id=2)"]
|
2 [label ="RuleNode<compute_graph::tests::graphviz::AddWithLabel>(test) (id=2)"]
|
||||||
0 -> 2 []
|
0 -> 2 []
|
||||||
1 -> 2 []
|
1 -> 2 []
|
||||||
}
|
}
|
||||||
|
@ -85,16 +85,13 @@ pub(crate) trait Node<Value: NodeValue, Synch: Synchronicity>: std::fmt::Debug {
|
|||||||
///
|
///
|
||||||
/// This trait is used to determine, when a node is invalidated, whether its value has truly changed
|
/// This trait is used to determine, when a node is invalidated, whether its value has truly changed
|
||||||
/// and thus whether downstream nodes need to be invalidated too.
|
/// and thus whether downstream nodes need to be invalidated too.
|
||||||
///
|
|
||||||
/// A blanket implementation of this trait for all types implementing `PartialEq` is provided.
|
|
||||||
pub trait NodeValue: 'static {
|
pub trait NodeValue: 'static {
|
||||||
/// Whether self is equal, for the purposes of graph invalidation, from other.
|
/// Whether self is equal, for the purposes of graph invalidation, from other.
|
||||||
///
|
///
|
||||||
/// This method should be conservative. That is, if the equality of the two values cannot be affirmatively
|
/// This method should be conservative. That is, if the equality of the two values cannot be affirmatively
|
||||||
/// determined, this method should return `false`.
|
/// determined, this method should return `false`.
|
||||||
///
|
///
|
||||||
/// The default implementation of this method always returns `false`, so any non-`PartialEq` type can
|
/// The default implementation of this method always returns `false`.
|
||||||
/// implement this trait simply:
|
|
||||||
///
|
///
|
||||||
/// ```rust
|
/// ```rust
|
||||||
/// # use compute_graph::node::NodeValue;
|
/// # use compute_graph::node::NodeValue;
|
||||||
@ -109,12 +106,36 @@ pub trait NodeValue: 'static {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: PartialEq + 'static> NodeValue for T {
|
impl NodeValue for () {
|
||||||
|
fn node_value_eq(&self, _other: &Self) -> bool {
|
||||||
|
true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl NodeValue for i32 {
|
||||||
fn node_value_eq(&self, other: &Self) -> bool {
|
fn node_value_eq(&self, other: &Self) -> bool {
|
||||||
self == other
|
self == other
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl<T: NodeValue> NodeValue for Option<T> {
|
||||||
|
fn node_value_eq(&self, other: &Self) -> bool {
|
||||||
|
match (self, other) {
|
||||||
|
(Some(s), Some(o)) => s.node_value_eq(o),
|
||||||
|
_ => false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: NodeValue, E: 'static> NodeValue for Result<T, E> {
|
||||||
|
fn node_value_eq(&self, other: &Self) -> bool {
|
||||||
|
match (self, other) {
|
||||||
|
(Ok(s), Ok(o)) => s.node_value_eq(o),
|
||||||
|
_ => false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub(crate) struct ConstNode<V, S> {
|
pub(crate) struct ConstNode<V, S> {
|
||||||
value: Rc<RefCell<Option<V>>>,
|
value: Rc<RefCell<Option<V>>>,
|
||||||
synchronicity: std::marker::PhantomData<S>,
|
synchronicity: std::marker::PhantomData<S>,
|
||||||
@ -256,9 +277,21 @@ impl<R: Rule, S: Synchronicity> Node<R::Output, S> for RuleNode<R, R::Output, S>
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<R, V, S> std::fmt::Debug for RuleNode<R, V, S> {
|
struct RuleLabel<'a, R: Rule>(&'a R);
|
||||||
|
impl<'a, R: Rule> std::fmt::Display for RuleLabel<'a, R> {
|
||||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
write!(f, "RuleNode<{}>", std::any::type_name::<R>())
|
self.0.node_label(f)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<R: Rule, V, S> std::fmt::Debug for RuleNode<R, V, S> {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
write!(
|
||||||
|
f,
|
||||||
|
"RuleNode<{}>({})",
|
||||||
|
std::any::type_name::<R>(),
|
||||||
|
RuleLabel(&self.rule)
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -374,8 +407,20 @@ impl<R: AsyncRule> Node<R::Output, Asynchronous> for AsyncRuleNode<R, R::Output>
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<R, V> std::fmt::Debug for AsyncRuleNode<R, V> {
|
struct AsyncRuleLabel<'a, R: AsyncRule>(&'a R);
|
||||||
|
impl<'a, R: AsyncRule> std::fmt::Display for AsyncRuleLabel<'a, R> {
|
||||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
write!(f, "AsyncRuleNode<{}>", std::any::type_name::<R>())
|
self.0.node_label(f)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<R: AsyncRule, V> std::fmt::Debug for AsyncRuleNode<R, V> {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
write!(
|
||||||
|
f,
|
||||||
|
"AsyncRuleNode<{}>({})",
|
||||||
|
std::any::type_name::<R>(),
|
||||||
|
AsyncRuleLabel(&self.rule)
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -38,6 +38,11 @@ pub trait Rule: InputVisitable + 'static {
|
|||||||
/// 2. A rule will never be evaluated before _all_ of its dependencies up-to-date. That is, it will never
|
/// 2. A rule will never be evaluated before _all_ of its dependencies up-to-date. That is, it will never
|
||||||
/// be evaluated with mix of valid and invalid dependencies.
|
/// be evaluated with mix of valid and invalid dependencies.
|
||||||
fn evaluate(&mut self) -> Self::Output;
|
fn evaluate(&mut self) -> Self::Output;
|
||||||
|
|
||||||
|
#[allow(unused_variables)]
|
||||||
|
fn node_label(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// A rule produces a value for a graph node asynchronously.
|
/// A rule produces a value for a graph node asynchronously.
|
||||||
@ -64,6 +69,11 @@ pub trait AsyncRule: InputVisitable + 'static {
|
|||||||
///
|
///
|
||||||
/// See [`Rule::evaluate`] for additional details; the same considerations apply.
|
/// See [`Rule::evaluate`] for additional details; the same considerations apply.
|
||||||
fn evaluate(&mut self) -> impl Future<Output = Self::Output> + '_;
|
fn evaluate(&mut self) -> impl Future<Output = Self::Output> + '_;
|
||||||
|
|
||||||
|
#[allow(unused_variables)]
|
||||||
|
fn node_label(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Common supertrait of [`Rule`] and [`AsyncRule`] that defines how rule inputs are visited.
|
/// Common supertrait of [`Rule`] and [`AsyncRule`] that defines how rule inputs are visited.
|
||||||
|
Loading…
x
Reference in New Issue
Block a user