Assorted graph tweaks

This commit is contained in:
Shadowfacts 2024-11-03 14:42:45 -05:00
parent 712b528ca8
commit e69014d98d
4 changed files with 95 additions and 11 deletions

View File

@ -333,6 +333,14 @@ pub enum BuildGraphError {
Cycle(Vec<NodeId>),
}
impl std::fmt::Display for BuildGraphError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{self:?}")
}
}
impl std::error::Error for BuildGraphError {}
#[cfg(test)]
mod tests {
use crate::tests::{DeferredInput, Double};

View File

@ -618,6 +618,11 @@ mod tests {
fn non_cloneable_output() {
#[derive(PartialEq, Debug)]
struct NonCloneable;
impl NodeValue for NonCloneable {
fn node_value_eq(&self, _other: &Self) -> bool {
true
}
}
struct Output;
impl InputVisitable for Output {
fn visit_inputs(&self, _visitor: &mut impl InputVisitor) {}
@ -689,7 +694,23 @@ mod tests {
let mut builder = GraphBuilder::new();
let a = builder.add_value(1);
let b = builder.add_value(2);
builder.set_output(Add(a, b));
struct AddWithLabel(Input<i32>, Input<i32>);
impl InputVisitable for AddWithLabel {
fn visit_inputs(&self, visitor: &mut impl InputVisitor) {
visitor.visit(&self.0);
visitor.visit(&self.1);
}
}
impl Rule for AddWithLabel {
type Output = i32;
fn evaluate(&mut self) -> Self::Output {
*self.0.value() + *self.1.value()
}
fn node_label(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(f, "test")
}
}
builder.set_output(AddWithLabel(a, b));
let graph = builder.build().unwrap();
println!("{}", graph.as_dot_string());
assert_eq!(
@ -697,7 +718,7 @@ mod tests {
r#"digraph {
0 [label ="ConstNode<i32> (id=0)"]
1 [label ="ConstNode<i32> (id=1)"]
2 [label ="RuleNode<compute_graph::tests::Add> (id=2)"]
2 [label ="RuleNode<compute_graph::tests::graphviz::AddWithLabel>(test) (id=2)"]
0 -> 2 []
1 -> 2 []
}

View File

@ -85,16 +85,13 @@ pub(crate) trait Node<Value: NodeValue, Synch: Synchronicity>: std::fmt::Debug {
///
/// This trait is used to determine, when a node is invalidated, whether its value has truly changed
/// and thus whether downstream nodes need to be invalidated too.
///
/// A blanket implementation of this trait for all types implementing `PartialEq` is provided.
pub trait NodeValue: 'static {
/// Whether self is equal, for the purposes of graph invalidation, from other.
///
/// This method should be conservative. That is, if the equality of the two values cannot be affirmatively
/// determined, this method should return `false`.
///
/// The default implementation of this method always returns `false`, so any non-`PartialEq` type can
/// implement this trait simply:
/// The default implementation of this method always returns `false`.
///
/// ```rust
/// # use compute_graph::node::NodeValue;
@ -109,12 +106,36 @@ pub trait NodeValue: 'static {
}
}
impl<T: PartialEq + 'static> NodeValue for T {
impl NodeValue for () {
fn node_value_eq(&self, _other: &Self) -> bool {
true
}
}
impl NodeValue for i32 {
fn node_value_eq(&self, other: &Self) -> bool {
self == other
}
}
impl<T: NodeValue> NodeValue for Option<T> {
fn node_value_eq(&self, other: &Self) -> bool {
match (self, other) {
(Some(s), Some(o)) => s.node_value_eq(o),
_ => false,
}
}
}
impl<T: NodeValue, E: 'static> NodeValue for Result<T, E> {
fn node_value_eq(&self, other: &Self) -> bool {
match (self, other) {
(Ok(s), Ok(o)) => s.node_value_eq(o),
_ => false,
}
}
}
pub(crate) struct ConstNode<V, S> {
value: Rc<RefCell<Option<V>>>,
synchronicity: std::marker::PhantomData<S>,
@ -256,9 +277,21 @@ impl<R: Rule, S: Synchronicity> Node<R::Output, S> for RuleNode<R, R::Output, S>
}
}
impl<R, V, S> std::fmt::Debug for RuleNode<R, V, S> {
struct RuleLabel<'a, R: Rule>(&'a R);
impl<'a, R: Rule> std::fmt::Display for RuleLabel<'a, R> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "RuleNode<{}>", std::any::type_name::<R>())
self.0.node_label(f)
}
}
impl<R: Rule, V, S> std::fmt::Debug for RuleNode<R, V, S> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"RuleNode<{}>({})",
std::any::type_name::<R>(),
RuleLabel(&self.rule)
)
}
}
@ -374,8 +407,20 @@ impl<R: AsyncRule> Node<R::Output, Asynchronous> for AsyncRuleNode<R, R::Output>
}
}
impl<R, V> std::fmt::Debug for AsyncRuleNode<R, V> {
struct AsyncRuleLabel<'a, R: AsyncRule>(&'a R);
impl<'a, R: AsyncRule> std::fmt::Display for AsyncRuleLabel<'a, R> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "AsyncRuleNode<{}>", std::any::type_name::<R>())
self.0.node_label(f)
}
}
impl<R: AsyncRule, V> std::fmt::Debug for AsyncRuleNode<R, V> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"AsyncRuleNode<{}>({})",
std::any::type_name::<R>(),
AsyncRuleLabel(&self.rule)
)
}
}

View File

@ -38,6 +38,11 @@ pub trait Rule: InputVisitable + 'static {
/// 2. A rule will never be evaluated before _all_ of its dependencies up-to-date. That is, it will never
/// be evaluated with mix of valid and invalid dependencies.
fn evaluate(&mut self) -> Self::Output;
#[allow(unused_variables)]
fn node_label(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
Ok(())
}
}
/// A rule produces a value for a graph node asynchronously.
@ -64,6 +69,11 @@ pub trait AsyncRule: InputVisitable + 'static {
///
/// See [`Rule::evaluate`] for additional details; the same considerations apply.
fn evaluate(&mut self) -> impl Future<Output = Self::Output> + '_;
#[allow(unused_variables)]
fn node_label(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
Ok(())
}
}
/// Common supertrait of [`Rule`] and [`AsyncRule`] that defines how rule inputs are visited.