172 lines
4.4 KiB
Rust

use compute_graph::input::{DynamicInput, Input, InputVisitable};
use compute_graph::node::NodeValue;
use compute_graph::rule::Rule;
/// Test rule adding two graph inputs and a constant offset.
///
/// The third field is marked `#[skip_visit]`, so the `InputVisitable` derive
/// treats it as plain data rather than as a graph input.
#[derive(InputVisitable)]
struct Add(Input<i32>, Input<i32>, #[skip_visit] i32);

impl Rule for Add {
    type Output = i32;

    /// Returns `input_0 + input_1 + offset`.
    fn evaluate(&mut self) -> Self::Output {
        let offset = self.2;
        let lhs = *self.input_0();
        let rhs = *self.input_1();
        lhs + rhs + offset
    }
}
/// Named-field counterpart of `Add`: evaluates to `a + b + c`.
///
/// `c` is marked `#[skip_visit]`, so only `a` and `b` are reported as inputs
/// by the `InputVisitable` derive.
#[derive(InputVisitable)]
struct Add2 {
    a: Input<i32>,
    b: Input<i32>,
    #[skip_visit]
    c: i32,
}

impl Rule for Add2 {
    type Output = i32;

    /// Sums both inputs, then adds the constant `c`.
    fn evaluate(&mut self) -> Self::Output {
        let inputs_total = *self.a() + *self.b();
        inputs_total + self.c
    }
}
/// Test rule that forwards a clone of its single input's current value.
#[derive(InputVisitable)]
struct Passthrough<T: NodeValue + Clone>(Input<T>);

impl<T: NodeValue + Clone> Rule for Passthrough<T> {
    type Output = T;

    /// Returns an owned copy of the input's value.
    fn evaluate(&mut self) -> Self::Output {
        let current = self.input_0();
        T::clone(&current)
    }
}
/// Test rule summing the values of every input held by a `DynamicInput`.
#[derive(InputVisitable)]
struct Sum(DynamicInput<i32>);

impl Rule for Sum {
    type Output = i32;

    /// Adds up the current value of each dynamic input, starting from zero.
    fn evaluate(&mut self) -> Self::Output {
        let mut total = 0;
        for input in self.input_0().inputs.iter() {
            total += *input.value();
        }
        total
    }
}
/// Test enum exercising `#[derive(InputVisitable)]` across variant shapes.
///
/// Every variant holds exactly one `Input<i32>`, which `test_enum` expects to
/// be visited exactly once; fields marked `#[skip_visit]` are plain data.
#[derive(InputVisitable)]
enum E {
    /// Tuple variant: the leading `i32` is skipped, the `Input` is visited.
    A(#[skip_visit] i32, Input<i32>),
    /// Struct variant: `x` is skipped, `y` is visited.
    B {
        #[skip_visit]
        x: i32,
        y: Input<i32>,
    },
    /// Struct variant whose only field is a visited input.
    C {
        x: Input<i32>,
    },
}
#[cfg(test)]
mod tests {
    use compute_graph::{
        builder::GraphBuilder,
        input::InputVisitor,
        rule::{ConstantRule, DynamicRule},
        synchronicity::Synchronous,
    };

    use super::*;

    /// Tuple-struct rule: both inputs and the skipped constant are summed.
    #[test]
    fn test_add() {
        let mut builder = GraphBuilder::new();
        let a = builder.add_value(1);
        let b = builder.add_value(2);
        builder.set_output(Add(a, b, 3));
        let mut graph = builder.build().unwrap();
        assert_eq!(*graph.evaluate(), 6);
    }

    /// Named-field rule: both inputs and the skipped constant are summed.
    #[test]
    fn test_add2() {
        let mut builder = GraphBuilder::new();
        let a = builder.add_value(1);
        let b = builder.add_value(2);
        builder.set_output(Add2 { a, b, c: 3 });
        let mut graph = builder.build().unwrap();
        assert_eq!(*graph.evaluate(), 6);
    }

    /// A dynamic rule producing two constant children feeds `Sum`.
    #[test]
    fn test_sum() {
        #[derive(InputVisitable)]
        struct Dynamic;

        impl DynamicRule for Dynamic {
            type ChildOutput = i32;

            fn evaluate(
                &mut self,
                ctx: &mut impl compute_graph::rule::DynamicRuleContext,
            ) -> Vec<Input<Self::ChildOutput>> {
                vec![
                    ctx.add_rule(ConstantRule::new(1)),
                    ctx.add_rule(ConstantRule::new(2)),
                ]
            }
        }

        let mut builder = GraphBuilder::new();
        let dynamic_input = builder.add_dynamic_rule(Dynamic);
        builder.set_output(Sum(dynamic_input));
        let mut graph = builder.build().unwrap();
        assert_eq!(*graph.evaluate(), 3);
    }

    /// A field marked `#[skip_visit]` must never reach the visitor, even when
    /// it is an `Input`.
    #[test]
    fn test_ignore() {
        // `input` is never read directly (the derive is told to skip it), so
        // silence the unused/dead-code lint for this fixture.
        #[allow(unused)]
        #[derive(InputVisitable)]
        struct Ignore {
            #[skip_visit]
            input: Input<i32>,
        }

        /// Fails the test if any input is visited at all.
        struct Visitor;

        impl InputVisitor for Visitor {
            fn visit<T>(&mut self, _input: &Input<T>) {
                panic!("an input marked #[skip_visit] was visited");
            }
        }

        let mut builder = GraphBuilder::<i32, Synchronous>::new();
        Ignore {
            input: builder.add_value(0),
        }
        .visit_inputs(&mut Visitor);
    }

    /// Each variant of `E` must visit its single `Input` field exactly once and
    /// skip its `#[skip_visit]` fields.
    #[test]
    fn test_enum() {
        /// Records whether it was called and checks that only the expected
        /// input is reported.
        struct Visitor(bool, Input<i32>);

        impl InputVisitor for Visitor {
            fn visit<T>(&mut self, input: &Input<T>) {
                assert_eq!(input.node_id(), self.1.node_id());
                assert!(!self.0, "input visited more than once");
                self.0 = true;
            }
        }

        /// Visits `value` and asserts that `expected` was seen exactly once.
        fn assert_visits_once(value: &E, expected: &Input<i32>) {
            let mut visitor = Visitor(false, expected.clone());
            InputVisitable::visit_inputs(value, &mut visitor);
            assert!(visitor.0, "input was never visited");
        }

        let mut builder = GraphBuilder::<i32, Synchronous>::new();
        let input = builder.add_value(1);

        assert_visits_once(&E::A(0, input.clone()), &input);
        assert_visits_once(
            &E::B {
                x: 0,
                y: input.clone(),
            },
            &input,
        );
        assert_visits_once(&E::C { x: input.clone() }, &input);
    }
}