Derive macro
This commit is contained in:
parent
36bcbe3c9c
commit
08a4bf87dc
17
Cargo.lock
generated
17
Cargo.lock
generated
@ -360,10 +360,20 @@ dependencies = [
|
||||
name = "compute_graph"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"compute_graph_macros",
|
||||
"petgraph",
|
||||
"tokio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "compute_graph_macros"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.85",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "core-foundation"
|
||||
version = "0.9.3"
|
||||
@ -521,6 +531,13 @@ dependencies = [
|
||||
"syn 1.0.99",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "derive_test"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"compute_graph",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "digest"
|
||||
version = "0.10.3"
|
||||
|
@ -1,4 +1,4 @@
|
||||
workspace = { members = ["crates/compute_graph"] }
|
||||
workspace = { members = ["crates/compute_graph", "crates/compute_graph_macros", "crates/derive_test"] }
|
||||
|
||||
[package]
|
||||
name = "v6"
|
||||
|
@ -6,6 +6,7 @@ edition = "2021"
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
compute_graph_macros = { path = "../compute_graph_macros" }
|
||||
petgraph = "0.6.5"
|
||||
|
||||
[dev-dependencies]
|
||||
|
@ -92,17 +92,15 @@ impl<O: 'static, S: Synchronicity> GraphBuilder<O, S> {
|
||||
/// the value of the node can be replaced, invalidating the node in the process.
|
||||
///
|
||||
/// ```rust
|
||||
/// # use compute_graph::{builder::GraphBuilder, rule::{Rule, Input, InputVisitor}};
|
||||
/// # use compute_graph::{builder::GraphBuilder, rule::{Rule, Input, InputVisitable}};
|
||||
/// let mut builder = GraphBuilder::new();
|
||||
/// let (input, signal) = builder.add_invalidatable_value(0);
|
||||
/// # #[derive(InputVisitable)]
|
||||
/// # struct Double(Input<i32>);
|
||||
/// # impl Rule for Double {
|
||||
/// # type Output = i32;
|
||||
/// # fn visit_inputs(&self, visitor: &mut impl InputVisitor) {
|
||||
/// # visitor.visit(&self.0);
|
||||
/// # }
|
||||
/// # fn evaluate(&mut self) -> i32 {
|
||||
/// # *self.0.value() * 2
|
||||
/// # *self.input_0() * 2
|
||||
/// # }
|
||||
/// # }
|
||||
/// builder.set_output(Double(input));
|
||||
@ -139,26 +137,24 @@ impl<O: 'static, S: Synchronicity> GraphBuilder<O, S> {
|
||||
/// as well as an [`InvalidationSignal`] which can be used to indicate that the node has been invalidated.
|
||||
///
|
||||
/// ```rust
|
||||
/// # use compute_graph::{builder::GraphBuilder, rule::{Rule, Input, InputVisitor}};
|
||||
/// # use compute_graph::{builder::GraphBuilder, rule::{Rule, Input, InputVisitable}};
|
||||
/// let mut builder = GraphBuilder::new();
|
||||
/// # #[derive(InputVisitable)]
|
||||
/// # struct IncrementAfterEvaluate(i32);
|
||||
/// # impl Rule for IncrementAfterEvaluate {
|
||||
/// # type Output = i32;
|
||||
/// # fn visit_inputs(&self, visitor: &mut impl InputVisitor) {}
|
||||
/// # fn evaluate(&mut self) -> i32 {
|
||||
/// # let result = self.0;
|
||||
/// # self.0 += 1;
|
||||
/// # result
|
||||
/// # }
|
||||
/// # }
|
||||
/// # #[derive(InputVisitable)]
|
||||
/// # struct Double(Input<i32>);
|
||||
/// # impl Rule for Double {
|
||||
/// # type Output = i32;
|
||||
/// # fn visit_inputs(&self, visitor: &mut impl InputVisitor) {
|
||||
/// # visitor.visit(&self.0);
|
||||
/// # }
|
||||
/// # fn evaluate(&mut self) -> i32 {
|
||||
/// # *self.0.value() * 2
|
||||
/// # *self.input_0() * 2
|
||||
/// # }
|
||||
/// # }
|
||||
/// let (input, signal) = builder.add_invalidatable_rule(IncrementAfterEvaluate(1));
|
||||
|
@ -7,19 +7,16 @@
|
||||
//! dependencies. For example, an arithmetic operation can be implemented like so:
|
||||
//!
|
||||
//! ```rust
|
||||
//! # use compute_graph::{builder::GraphBuilder, rule::{Rule, Input, InputVisitor}};
|
||||
//! # use compute_graph::{builder::GraphBuilder, rule::{Rule, Input, InputVisitable}};
|
||||
//! let mut builder = GraphBuilder::new();
|
||||
//! let a = builder.add_value(1);
|
||||
//! let b = builder.add_value(2);
|
||||
//! # #[derive(InputVisitable)]
|
||||
//! # struct Add(Input<i32>, Input<i32>);
|
||||
//! # impl Rule for Add {
|
||||
//! # type Output = i32;
|
||||
//! # fn visit_inputs(&self, visitor: &mut impl InputVisitor) {
|
||||
//! # visitor.visit(&self.0);
|
||||
//! # visitor.visit(&self.1);
|
||||
//! # }
|
||||
//! # fn evaluate(&mut self) -> i32 {
|
||||
//! # *self.0.value() + *self.1.value()
|
||||
//! # *self.input_0() + *self.input_1()
|
||||
//! # }
|
||||
//! # }
|
||||
//! builder.set_output(Add(a, b));
|
||||
@ -33,17 +30,14 @@
|
||||
//! The `Add` rule is implemented as follows:
|
||||
//!
|
||||
//! ```rust
|
||||
//! # use compute_graph::{builder::GraphBuilder, rule::{Rule, Input, InputVisitor}};
|
||||
//! # use compute_graph::{builder::GraphBuilder, rule::{Rule, Input, InputVisitable}};
|
||||
//! #[derive(InputVisitable)]
|
||||
//! struct Add(Input<i32>, Input<i32>);
|
||||
//!
|
||||
//! impl Rule for Add {
|
||||
//! type Output = i32;
|
||||
//! fn visit_inputs(&self, visitor: &mut impl InputVisitor) {
|
||||
//! visitor.visit(&self.0);
|
||||
//! visitor.visit(&self.1);
|
||||
//! }
|
||||
//! fn evaluate(&mut self) -> i32 {
|
||||
//! *self.0.value() + *self.1.value()
|
||||
//! *self.input_0() + *self.input_1()
|
||||
//! }
|
||||
//! }
|
||||
//! ```
|
||||
@ -373,7 +367,7 @@ impl<V> Clone for ValueInvalidationSignal<V> {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::rule::ConstantRule;
|
||||
use crate::rule::{ConstantRule, InputVisitable};
|
||||
|
||||
#[test]
|
||||
fn rule_output_with_no_inputs() {
|
||||
@ -393,11 +387,13 @@ mod tests {
|
||||
}
|
||||
|
||||
struct Double(Input<i32>);
|
||||
impl Rule for Double {
|
||||
type Output = i32;
|
||||
impl InputVisitable for Double {
|
||||
fn visit_inputs(&self, visitor: &mut impl InputVisitor) {
|
||||
visitor.visit(&self.0);
|
||||
}
|
||||
}
|
||||
impl Rule for Double {
|
||||
type Output = i32;
|
||||
fn evaluate(&mut self) -> i32 {
|
||||
*self.0.value() * 2
|
||||
}
|
||||
@ -421,9 +417,11 @@ mod tests {
|
||||
}
|
||||
|
||||
struct Inc(i32);
|
||||
impl InputVisitable for Inc {
|
||||
fn visit_inputs(&self, _visitor: &mut impl InputVisitor) {}
|
||||
}
|
||||
impl Rule for Inc {
|
||||
type Output = i32;
|
||||
fn visit_inputs(&self, _visitor: &mut impl InputVisitor) {}
|
||||
fn evaluate(&mut self) -> i32 {
|
||||
self.0 += 1;
|
||||
return self.0;
|
||||
@ -445,12 +443,14 @@ mod tests {
|
||||
}
|
||||
|
||||
struct Add(Input<i32>, Input<i32>);
|
||||
impl Rule for Add {
|
||||
type Output = i32;
|
||||
impl InputVisitable for Add {
|
||||
fn visit_inputs(&self, visitor: &mut impl InputVisitor) {
|
||||
visitor.visit(&self.0);
|
||||
visitor.visit(&self.1);
|
||||
}
|
||||
}
|
||||
impl Rule for Add {
|
||||
type Output = i32;
|
||||
fn evaluate(&mut self) -> i32 {
|
||||
*self.0.value() + *self.1.value()
|
||||
}
|
||||
@ -489,13 +489,15 @@ mod tests {
|
||||
}
|
||||
|
||||
struct DeferredInput(Rc<RefCell<Option<Input<i32>>>>);
|
||||
impl Rule for DeferredInput {
|
||||
type Output = i32;
|
||||
impl InputVisitable for DeferredInput {
|
||||
fn visit_inputs(&self, visitor: &mut impl InputVisitor) {
|
||||
let borrowed = self.0.borrow();
|
||||
let input = borrowed.as_ref().unwrap();
|
||||
visitor.visit(input);
|
||||
}
|
||||
}
|
||||
impl Rule for DeferredInput {
|
||||
type Output = i32;
|
||||
fn evaluate(&mut self) -> i32 {
|
||||
*self.0.borrow().as_ref().unwrap().value()
|
||||
}
|
||||
@ -560,9 +562,11 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn async_rule() {
|
||||
struct AsyncConst(i32);
|
||||
impl InputVisitable for AsyncConst {
|
||||
fn visit_inputs(&self, _visitor: &mut impl InputVisitor) {}
|
||||
}
|
||||
impl AsyncRule for AsyncConst {
|
||||
type Output = i32;
|
||||
fn visit_inputs(&self, _visitor: &mut impl InputVisitor) {}
|
||||
async fn evaluate(&mut self) -> i32 {
|
||||
self.0
|
||||
}
|
||||
@ -578,9 +582,11 @@ mod tests {
|
||||
#[derive(PartialEq, Debug)]
|
||||
struct NonCloneable;
|
||||
struct Output;
|
||||
impl InputVisitable for Output {
|
||||
fn visit_inputs(&self, _visitor: &mut impl InputVisitor) {}
|
||||
}
|
||||
impl Rule for Output {
|
||||
type Output = NonCloneable;
|
||||
fn visit_inputs(&self, _visitor: &mut impl InputVisitor) {}
|
||||
fn evaluate(&mut self) -> Self::Output {
|
||||
NonCloneable
|
||||
}
|
||||
@ -596,11 +602,13 @@ mod tests {
|
||||
let mut builder = GraphBuilder::new();
|
||||
let (a, invalidate) = builder.add_invalidatable_rule(ConstantRule::new(0));
|
||||
struct IncAdd(Input<i32>, i32);
|
||||
impl Rule for IncAdd {
|
||||
type Output = i32;
|
||||
impl InputVisitable for IncAdd {
|
||||
fn visit_inputs(&self, visitor: &mut impl InputVisitor) {
|
||||
visitor.visit(&self.0);
|
||||
}
|
||||
}
|
||||
impl Rule for IncAdd {
|
||||
type Output = i32;
|
||||
fn evaluate(&mut self) -> Self::Output {
|
||||
self.1 += 1;
|
||||
*self.0.value() + self.1
|
||||
|
@ -1,5 +1,6 @@
|
||||
use crate::node::NodeValue;
|
||||
use crate::NodeId;
|
||||
pub use compute_graph_macros::InputVisitable;
|
||||
use std::cell::{Ref, RefCell};
|
||||
use std::ops::Deref;
|
||||
use std::rc::Rc;
|
||||
@ -9,38 +10,22 @@ use std::rc::Rc;
|
||||
/// A rule for addition could be implemented like so:
|
||||
///
|
||||
/// ```rust
|
||||
/// # use compute_graph::rule::{Rule, Input, InputVisitor};
|
||||
/// # use compute_graph::rule::{Rule, Input, InputVisitable};
|
||||
/// #[derive(InputVisitable)]
|
||||
/// struct Add(Input<i32>, Input<i32>);
|
||||
///
|
||||
/// impl Rule for Add {
|
||||
/// type Output = i32;
|
||||
///
|
||||
/// fn visit_inputs(&self, visitor: &mut impl InputVisitor) {
|
||||
/// visitor.visit(&self.0);
|
||||
/// visitor.visit(&self.1);
|
||||
/// }
|
||||
///
|
||||
/// fn evaluate(&mut self) -> Self::Output {
|
||||
/// *self.0.value() + *self.1.value()
|
||||
/// *self.input_0() + *self.input_1()
|
||||
/// }
|
||||
/// }
|
||||
/// ```
|
||||
pub trait Rule: 'static {
|
||||
pub trait Rule: InputVisitable + 'static {
|
||||
/// The type of the output value of the rule.
|
||||
type Output: NodeValue;
|
||||
|
||||
/// Visits all the [`Input`]s of this rule.
|
||||
///
|
||||
/// This method is called when the graph is built/modified in order to establish edges of the graph,
|
||||
/// representing the dependencies. Any input that the [`InputVisitor::visit`] is called with is
|
||||
/// considered a dependency of the rule's node.
|
||||
///
|
||||
/// While it is permitted for the dependencies of a rule to change after it has been added to the graph,
|
||||
/// doing so only permitted before the graph has been built or during the callback of
|
||||
/// [`Graph::modify`](`crate::Graph::modify`). Changes to the rule's dependencies outside of that will
|
||||
/// not be detected and will not be represented in the graph.
|
||||
fn visit_inputs(&self, visitor: &mut impl InputVisitor);
|
||||
|
||||
/// Produces the value of this rule using its inputs.
|
||||
///
|
||||
/// Note that the receiver of this method is a mutable reference to the rule itself. Rules are permitted
|
||||
@ -57,37 +42,46 @@ pub trait Rule: 'static {
|
||||
/// A rule produces a value for a graph node asynchronously.
|
||||
///
|
||||
/// ```rust
|
||||
/// # use compute_graph::rule::{AsyncRule, Input, InputVisitor};
|
||||
/// # use compute_graph::rule::{AsyncRule, Input, InputVisitable};
|
||||
/// # async fn do_async_work(_: i32) -> i32 { 0 }
|
||||
/// #[derive(InputVisitable)]
|
||||
/// struct AsyncMath(Input<i32>);
|
||||
///
|
||||
/// impl AsyncRule for AsyncMath {
|
||||
/// type Output = i32;
|
||||
///
|
||||
/// fn visit_inputs(&self, visitor: &mut impl InputVisitor) {
|
||||
/// visitor.visit(&self.0);
|
||||
/// }
|
||||
///
|
||||
/// async fn evaluate(&mut self) -> Self::Output {
|
||||
/// do_async_work(*self.0.value()).await
|
||||
/// do_async_work(*self.input_0()).await
|
||||
/// }
|
||||
/// }
|
||||
/// ```
|
||||
pub trait AsyncRule: 'static {
|
||||
pub trait AsyncRule: InputVisitable + 'static {
|
||||
/// The type of the output value of the rule.
|
||||
type Output: NodeValue;
|
||||
|
||||
/// Visits all the [`Input`]s of this rule.
|
||||
///
|
||||
/// See [`Rule::visit_inputs`] for additional details; the same caveats apply.
|
||||
fn visit_inputs(&self, visitor: &mut impl InputVisitor);
|
||||
|
||||
/// Asynchronously produces the value of this rule using its inputs.
|
||||
///
|
||||
/// See [`Rule::evaluate`] for additional details; the same considerations apply.
|
||||
async fn evaluate(&mut self) -> Self::Output;
|
||||
}
|
||||
|
||||
/// Common supertrait of [`Rule`] and [`AsyncRule`] that defines how rule inputs are visited.
|
||||
///
|
||||
/// The implementation of this trait can generally be derived using [`derive@InputVisitable`].
|
||||
pub trait InputVisitable {
|
||||
/// Visits all the [`Input`]s of this rule.
|
||||
///
|
||||
/// This method is called when the graph is built/modified in order to establish edges of the graph,
|
||||
/// representing the dependencies. Any input that the [`InputVisitor::visit`] is called with is
|
||||
/// considered a dependency of the rule's node.
|
||||
///
|
||||
/// While it is permitted for the dependencies of a rule to change after it has been added to the graph,
|
||||
/// doing so only permitted before the graph has been built or during the callback of
|
||||
/// [`Graph::modify`](`crate::Graph::modify`). Changes to the rule's dependencies outside of that will
|
||||
/// not be detected and will not be represented in the graph.
|
||||
fn visit_inputs(&self, visitor: &mut impl InputVisitor);
|
||||
}
|
||||
|
||||
/// A placeholder for the output of one node to be used as an input for another.
|
||||
///
|
||||
/// To obtain an input, add a value or rule to a [`GraphBuilder`](`crate::builder::GraphBuilder`).
|
||||
@ -125,7 +119,8 @@ impl<T> Clone for Input<T> {
|
||||
|
||||
/// A type that can visit arbitrary [`Input`]s.
|
||||
///
|
||||
/// You generally do not implement this trait yourself. An implementation is provided to [`Rule::visit_inputs`].
|
||||
/// You generally do not implement this trait yourself. An implementation is provided to
|
||||
/// [`InputVisitable::visit_inputs`].
|
||||
pub trait InputVisitor {
|
||||
/// Visit an input whose value is of type `T`.
|
||||
fn visit<T>(&mut self, input: &Input<T>);
|
||||
@ -146,9 +141,11 @@ impl<T> ConstantRule<T> {
|
||||
impl<T: Clone + NodeValue> Rule for ConstantRule<T> {
|
||||
type Output = T;
|
||||
|
||||
fn visit_inputs(&self, _visitor: &mut impl InputVisitor) {}
|
||||
|
||||
fn evaluate(&mut self) -> Self::Output {
|
||||
self.0.clone()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> InputVisitable for ConstantRule<T> {
|
||||
fn visit_inputs(&self, _visitor: &mut impl InputVisitor) {}
|
||||
}
|
||||
|
12
crates/compute_graph_macros/Cargo.toml
Normal file
12
crates/compute_graph_macros/Cargo.toml
Normal file
@ -0,0 +1,12 @@
|
||||
[package]
|
||||
name = "compute_graph_macros"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[lib]
|
||||
proc-macro = true
|
||||
|
||||
[dependencies]
|
||||
syn = "2"
|
||||
quote = "1"
|
||||
proc-macro2 = "1"
|
134
crates/compute_graph_macros/src/lib.rs
Normal file
134
crates/compute_graph_macros/src/lib.rs
Normal file
@ -0,0 +1,134 @@
|
||||
use proc_macro::TokenStream;
|
||||
use proc_macro2::Literal;
|
||||
use quote::{format_ident, quote};
|
||||
use syn::{
|
||||
parse_macro_input, Data, DataStruct, DeriveInput, Field, Fields, GenericArgument,
|
||||
PathArguments, Type,
|
||||
};
|
||||
|
||||
extern crate proc_macro;
|
||||
|
||||
/// Derive an implementation of the `InputVisitable` trait and helper methods.
|
||||
///
|
||||
/// This macro generates an implementation of the `InputVisitable` trait and the `visit_input` method that
|
||||
/// calls `visit` on each field of the struct that is of type `Input<T>` for any T.
|
||||
///
|
||||
/// The macro also generates helper methods for accessing the value of each input less verbosely.
|
||||
/// For unnamed struct fields, the methods generated have the form `input_0`, `input_1`, etc.
|
||||
/// For named fields, the generated method name matches the field name. In both cases, the method
|
||||
/// returns a reference to the input value. As with the `Input::value` method, calling the helper methods
|
||||
/// before the referenced node has been evaluated is forbidden.
|
||||
#[proc_macro_derive(InputVisitable)]
|
||||
pub fn derive_rule(input: TokenStream) -> TokenStream {
|
||||
let input = parse_macro_input!(input as DeriveInput);
|
||||
if let Data::Struct(ref data) = input.data {
|
||||
derive_rule_struct(&input, data)
|
||||
} else {
|
||||
TokenStream::from(
|
||||
syn::Error::new(input.ident.span(), "Only structs can derive `Rule`")
|
||||
.to_compile_error(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn derive_rule_struct(input: &DeriveInput, data: &DataStruct) -> TokenStream {
|
||||
let name = &input.ident;
|
||||
|
||||
let visit_inputs = match data.fields {
|
||||
Fields::Named(ref named) => named
|
||||
.named
|
||||
.iter()
|
||||
.filter(|field| input_value_type(field).is_some())
|
||||
.map(|field| {
|
||||
let ident = field.ident.as_ref().unwrap();
|
||||
quote!(visitor.visit(&self.#ident);)
|
||||
})
|
||||
.collect::<Vec<_>>(),
|
||||
Fields::Unnamed(ref unnamed) => unnamed
|
||||
.unnamed
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter(|(_, field)| input_value_type(field).is_some())
|
||||
.map(|(i, _)| {
|
||||
let idx_lit = Literal::usize_unsuffixed(i);
|
||||
quote!(visitor.visit(&self.#idx_lit);)
|
||||
})
|
||||
.collect::<Vec<_>>(),
|
||||
Fields::Unit => vec![],
|
||||
};
|
||||
|
||||
let input_value_methods = match data.fields {
|
||||
Fields::Named(ref named) => named
|
||||
.named
|
||||
.iter()
|
||||
.filter_map(|field| input_value_type(field).map(|ty| (field, ty)))
|
||||
.map(|(field, ty)| {
|
||||
let ident = field.ident.as_ref().unwrap();
|
||||
quote!(
|
||||
|
||||
fn #ident(&self) -> impl ::std::ops::Deref<Target = #ty> + '_ {
|
||||
self.#ident.value()
|
||||
}
|
||||
|
||||
)
|
||||
})
|
||||
.collect::<Vec<_>>(),
|
||||
Fields::Unnamed(ref unnamed) => unnamed
|
||||
.unnamed
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter_map(|(i, field)| input_value_type(field).map(|ty| (i, ty)))
|
||||
.map(|(i, ty)| {
|
||||
let idx_lit = Literal::usize_unsuffixed(i);
|
||||
let ident = format_ident!("input_{i}");
|
||||
quote!(
|
||||
|
||||
fn #ident(&self) -> impl ::std::ops::Deref<Target = #ty> + '_ {
|
||||
self.#idx_lit.value()
|
||||
}
|
||||
|
||||
)
|
||||
})
|
||||
.collect::<Vec<_>>(),
|
||||
Fields::Unit => vec![],
|
||||
};
|
||||
|
||||
TokenStream::from(quote!(
|
||||
|
||||
impl ::compute_graph::rule::InputVisitable for #name {
|
||||
fn visit_inputs(&self, visitor: &mut impl ::compute_graph::rule::InputVisitor) {
|
||||
#(#visit_inputs)*
|
||||
}
|
||||
}
|
||||
|
||||
impl #name {
|
||||
#(#input_value_methods)*
|
||||
}
|
||||
|
||||
))
|
||||
}
|
||||
|
||||
fn input_value_type(field: &Field) -> Option<&Type> {
|
||||
if let Type::Path(ref path) = field.ty {
|
||||
let last_segment = path.path.segments.last().unwrap();
|
||||
if last_segment.ident == "Input" {
|
||||
if let PathArguments::AngleBracketed(ref args) = last_segment.arguments {
|
||||
if args.args.len() == 1 {
|
||||
if let GenericArgument::Type(ref ty) = args.args.first().unwrap() {
|
||||
Some(ty)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
} else {
|
||||
None
|
||||
}
|
||||
} else {
|
||||
None
|
||||
}
|
||||
} else {
|
||||
None
|
||||
}
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
7
crates/derive_test/Cargo.toml
Normal file
7
crates/derive_test/Cargo.toml
Normal file
@ -0,0 +1,7 @@
|
||||
[package]
|
||||
name = "derive_test"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
compute_graph = { path = "../compute_graph" }
|
52
crates/derive_test/src/lib.rs
Normal file
52
crates/derive_test/src/lib.rs
Normal file
@ -0,0 +1,52 @@
|
||||
use compute_graph::rule::{Input, InputVisitable, Rule};
|
||||
|
||||
#[derive(InputVisitable)]
|
||||
struct Add(Input<i32>, Input<i32>, i32);
|
||||
|
||||
impl Rule for Add {
|
||||
type Output = i32;
|
||||
fn evaluate(&mut self) -> Self::Output {
|
||||
*self.input_0() + *self.input_1() + self.2
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(InputVisitable)]
|
||||
struct Add2 {
|
||||
a: Input<i32>,
|
||||
b: Input<i32>,
|
||||
c: i32,
|
||||
}
|
||||
|
||||
impl Rule for Add2 {
|
||||
type Output = i32;
|
||||
fn evaluate(&mut self) -> Self::Output {
|
||||
*self.a() + *self.b() + self.c
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use compute_graph::builder::GraphBuilder;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_add() {
|
||||
let mut builder = GraphBuilder::new();
|
||||
let a = builder.add_value(1);
|
||||
let b = builder.add_value(2);
|
||||
builder.set_output(Add(a, b, 3));
|
||||
let mut graph = builder.build().unwrap();
|
||||
assert_eq!(*graph.evaluate(), 6);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_add2() {
|
||||
let mut builder = GraphBuilder::new();
|
||||
let a = builder.add_value(1);
|
||||
let b = builder.add_value(2);
|
||||
builder.set_output(Add2 { a, b, c: 3 });
|
||||
let mut graph = builder.build().unwrap();
|
||||
assert_eq!(*graph.evaluate(), 6);
|
||||
}
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user