use compute_graph::input::{DynamicInput, Input, InputVisitable}; use compute_graph::node::NodeValue; use compute_graph::rule::Rule; #[derive(InputVisitable)] struct Add(Input, Input, #[skip_visit] i32); impl Rule for Add { type Output = i32; fn evaluate(&mut self) -> Self::Output { *self.input_0() + *self.input_1() + self.2 } } #[derive(InputVisitable)] struct Add2 { a: Input, b: Input, #[skip_visit] c: i32, } impl Rule for Add2 { type Output = i32; fn evaluate(&mut self) -> Self::Output { *self.a() + *self.b() + self.c } } #[derive(InputVisitable)] struct Passthrough(Input); impl Rule for Passthrough { type Output = T; fn evaluate(&mut self) -> Self::Output { self.input_0().clone() } } #[derive(InputVisitable)] struct Sum(DynamicInput); impl Rule for Sum { type Output = i32; fn evaluate(&mut self) -> Self::Output { self.input_0() .inputs .iter() .map(|input| *input.value()) .sum() } } #[derive(InputVisitable)] enum E { A(#[skip_visit] i32, Input), B { #[skip_visit] x: i32, y: Input, }, C { x: Input, }, } #[cfg(test)] mod tests { use compute_graph::{ builder::GraphBuilder, input::InputVisitor, rule::{ConstantRule, DynamicRule}, synchronicity::Synchronous, }; use super::*; #[test] fn test_add() { let mut builder = GraphBuilder::new(); let a = builder.add_value(1); let b = builder.add_value(2); builder.set_output(Add(a, b, 3)); let mut graph = builder.build().unwrap(); assert_eq!(*graph.evaluate(), 6); } #[test] fn test_add2() { let mut builder = GraphBuilder::new(); let a = builder.add_value(1); let b = builder.add_value(2); builder.set_output(Add2 { a, b, c: 3 }); let mut graph = builder.build().unwrap(); assert_eq!(*graph.evaluate(), 6); } #[test] fn test_sum() { #[derive(InputVisitable)] struct Dynamic; impl DynamicRule for Dynamic { type ChildOutput = i32; fn evaluate( &mut self, ctx: &mut impl compute_graph::rule::DynamicRuleContext, ) -> Vec> { vec![ ctx.add_rule(ConstantRule::new(1)), ctx.add_rule(ConstantRule::new(2)), ] } } let mut builder = GraphBuilder::new(); let dynamic_input = builder.add_dynamic_rule(Dynamic); builder.set_output(Sum(dynamic_input)); let mut graph = builder.build().unwrap(); assert_eq!(*graph.evaluate(), 3); } #[test] fn test_ignore() { #[allow(unused)] #[derive(InputVisitable)] struct Ignore { #[skip_visit] input: Input, } let mut builder = GraphBuilder::::new(); struct Visitor; impl InputVisitor for Visitor { fn visit(&mut self, _input: &Input) { assert!(false); } } Ignore { input: builder.add_value(0), } .visit_inputs(&mut Visitor); } #[test] fn test_enum() { let mut builder = GraphBuilder::::new(); let input = builder.add_value(1); struct Visitor(bool, Input); impl InputVisitor for Visitor { fn visit(&mut self, input: &Input) { assert_eq!(input.node_id(), self.1.node_id()); assert!(!self.0); self.0 = true; } } let a = E::A(0, input.clone()); let mut visitor = Visitor(false, input.clone()); InputVisitable::visit_inputs(&a, &mut visitor); assert!(visitor.0); let b = E::B { x: 0, y: input.clone(), }; let mut visitor = Visitor(false, input.clone()); InputVisitable::visit_inputs(&b, &mut visitor); assert!(visitor.0); let c = E::C { x: input.clone() }; let mut visitor = Visitor(false, input); InputVisitable::visit_inputs(&c, &mut visitor); assert!(visitor.0); } }