Derive macro
This commit is contained in:
parent
36bcbe3c9c
commit
08a4bf87dc
17
Cargo.lock
generated
17
Cargo.lock
generated
@ -360,10 +360,20 @@ dependencies = [
|
|||||||
name = "compute_graph"
|
name = "compute_graph"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
|
"compute_graph_macros",
|
||||||
"petgraph",
|
"petgraph",
|
||||||
"tokio",
|
"tokio",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "compute_graph_macros"
|
||||||
|
version = "0.1.0"
|
||||||
|
dependencies = [
|
||||||
|
"proc-macro2",
|
||||||
|
"quote",
|
||||||
|
"syn 2.0.85",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "core-foundation"
|
name = "core-foundation"
|
||||||
version = "0.9.3"
|
version = "0.9.3"
|
||||||
@ -521,6 +531,13 @@ dependencies = [
|
|||||||
"syn 1.0.99",
|
"syn 1.0.99",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "derive_test"
|
||||||
|
version = "0.1.0"
|
||||||
|
dependencies = [
|
||||||
|
"compute_graph",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "digest"
|
name = "digest"
|
||||||
version = "0.10.3"
|
version = "0.10.3"
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
workspace = { members = ["crates/compute_graph"] }
|
workspace = { members = ["crates/compute_graph", "crates/compute_graph_macros", "crates/derive_test"] }
|
||||||
|
|
||||||
[package]
|
[package]
|
||||||
name = "v6"
|
name = "v6"
|
||||||
|
@ -6,6 +6,7 @@ edition = "2021"
|
|||||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
|
compute_graph_macros = { path = "../compute_graph_macros" }
|
||||||
petgraph = "0.6.5"
|
petgraph = "0.6.5"
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
|
@ -92,17 +92,15 @@ impl<O: 'static, S: Synchronicity> GraphBuilder<O, S> {
|
|||||||
/// the value of the node can be replaced, invalidating the node in the process.
|
/// the value of the node can be replaced, invalidating the node in the process.
|
||||||
///
|
///
|
||||||
/// ```rust
|
/// ```rust
|
||||||
/// # use compute_graph::{builder::GraphBuilder, rule::{Rule, Input, InputVisitor}};
|
/// # use compute_graph::{builder::GraphBuilder, rule::{Rule, Input, InputVisitable}};
|
||||||
/// let mut builder = GraphBuilder::new();
|
/// let mut builder = GraphBuilder::new();
|
||||||
/// let (input, signal) = builder.add_invalidatable_value(0);
|
/// let (input, signal) = builder.add_invalidatable_value(0);
|
||||||
|
/// # #[derive(InputVisitable)]
|
||||||
/// # struct Double(Input<i32>);
|
/// # struct Double(Input<i32>);
|
||||||
/// # impl Rule for Double {
|
/// # impl Rule for Double {
|
||||||
/// # type Output = i32;
|
/// # type Output = i32;
|
||||||
/// # fn visit_inputs(&self, visitor: &mut impl InputVisitor) {
|
|
||||||
/// # visitor.visit(&self.0);
|
|
||||||
/// # }
|
|
||||||
/// # fn evaluate(&mut self) -> i32 {
|
/// # fn evaluate(&mut self) -> i32 {
|
||||||
/// # *self.0.value() * 2
|
/// # *self.input_0() * 2
|
||||||
/// # }
|
/// # }
|
||||||
/// # }
|
/// # }
|
||||||
/// builder.set_output(Double(input));
|
/// builder.set_output(Double(input));
|
||||||
@ -139,26 +137,24 @@ impl<O: 'static, S: Synchronicity> GraphBuilder<O, S> {
|
|||||||
/// as well as an [`InvalidationSignal`] which can be used to indicate that the node has been invalidated.
|
/// as well as an [`InvalidationSignal`] which can be used to indicate that the node has been invalidated.
|
||||||
///
|
///
|
||||||
/// ```rust
|
/// ```rust
|
||||||
/// # use compute_graph::{builder::GraphBuilder, rule::{Rule, Input, InputVisitor}};
|
/// # use compute_graph::{builder::GraphBuilder, rule::{Rule, Input, InputVisitable}};
|
||||||
/// let mut builder = GraphBuilder::new();
|
/// let mut builder = GraphBuilder::new();
|
||||||
|
/// # #[derive(InputVisitable)]
|
||||||
/// # struct IncrementAfterEvaluate(i32);
|
/// # struct IncrementAfterEvaluate(i32);
|
||||||
/// # impl Rule for IncrementAfterEvaluate {
|
/// # impl Rule for IncrementAfterEvaluate {
|
||||||
/// # type Output = i32;
|
/// # type Output = i32;
|
||||||
/// # fn visit_inputs(&self, visitor: &mut impl InputVisitor) {}
|
|
||||||
/// # fn evaluate(&mut self) -> i32 {
|
/// # fn evaluate(&mut self) -> i32 {
|
||||||
/// # let result = self.0;
|
/// # let result = self.0;
|
||||||
/// # self.0 += 1;
|
/// # self.0 += 1;
|
||||||
/// # result
|
/// # result
|
||||||
/// # }
|
/// # }
|
||||||
/// # }
|
/// # }
|
||||||
|
/// # #[derive(InputVisitable)]
|
||||||
/// # struct Double(Input<i32>);
|
/// # struct Double(Input<i32>);
|
||||||
/// # impl Rule for Double {
|
/// # impl Rule for Double {
|
||||||
/// # type Output = i32;
|
/// # type Output = i32;
|
||||||
/// # fn visit_inputs(&self, visitor: &mut impl InputVisitor) {
|
|
||||||
/// # visitor.visit(&self.0);
|
|
||||||
/// # }
|
|
||||||
/// # fn evaluate(&mut self) -> i32 {
|
/// # fn evaluate(&mut self) -> i32 {
|
||||||
/// # *self.0.value() * 2
|
/// # *self.input_0() * 2
|
||||||
/// # }
|
/// # }
|
||||||
/// # }
|
/// # }
|
||||||
/// let (input, signal) = builder.add_invalidatable_rule(IncrementAfterEvaluate(1));
|
/// let (input, signal) = builder.add_invalidatable_rule(IncrementAfterEvaluate(1));
|
||||||
|
@ -7,19 +7,16 @@
|
|||||||
//! dependencies. For example, an arithmetic operation can be implemented like so:
|
//! dependencies. For example, an arithmetic operation can be implemented like so:
|
||||||
//!
|
//!
|
||||||
//! ```rust
|
//! ```rust
|
||||||
//! # use compute_graph::{builder::GraphBuilder, rule::{Rule, Input, InputVisitor}};
|
//! # use compute_graph::{builder::GraphBuilder, rule::{Rule, Input, InputVisitable}};
|
||||||
//! let mut builder = GraphBuilder::new();
|
//! let mut builder = GraphBuilder::new();
|
||||||
//! let a = builder.add_value(1);
|
//! let a = builder.add_value(1);
|
||||||
//! let b = builder.add_value(2);
|
//! let b = builder.add_value(2);
|
||||||
|
//! # #[derive(InputVisitable)]
|
||||||
//! # struct Add(Input<i32>, Input<i32>);
|
//! # struct Add(Input<i32>, Input<i32>);
|
||||||
//! # impl Rule for Add {
|
//! # impl Rule for Add {
|
||||||
//! # type Output = i32;
|
//! # type Output = i32;
|
||||||
//! # fn visit_inputs(&self, visitor: &mut impl InputVisitor) {
|
|
||||||
//! # visitor.visit(&self.0);
|
|
||||||
//! # visitor.visit(&self.1);
|
|
||||||
//! # }
|
|
||||||
//! # fn evaluate(&mut self) -> i32 {
|
//! # fn evaluate(&mut self) -> i32 {
|
||||||
//! # *self.0.value() + *self.1.value()
|
//! # *self.input_0() + *self.input_1()
|
||||||
//! # }
|
//! # }
|
||||||
//! # }
|
//! # }
|
||||||
//! builder.set_output(Add(a, b));
|
//! builder.set_output(Add(a, b));
|
||||||
@ -33,17 +30,14 @@
|
|||||||
//! The `Add` rule is implemented as follows:
|
//! The `Add` rule is implemented as follows:
|
||||||
//!
|
//!
|
||||||
//! ```rust
|
//! ```rust
|
||||||
//! # use compute_graph::{builder::GraphBuilder, rule::{Rule, Input, InputVisitor}};
|
//! # use compute_graph::{builder::GraphBuilder, rule::{Rule, Input, InputVisitable}};
|
||||||
|
//! #[derive(InputVisitable)]
|
||||||
//! struct Add(Input<i32>, Input<i32>);
|
//! struct Add(Input<i32>, Input<i32>);
|
||||||
//!
|
//!
|
||||||
//! impl Rule for Add {
|
//! impl Rule for Add {
|
||||||
//! type Output = i32;
|
//! type Output = i32;
|
||||||
//! fn visit_inputs(&self, visitor: &mut impl InputVisitor) {
|
|
||||||
//! visitor.visit(&self.0);
|
|
||||||
//! visitor.visit(&self.1);
|
|
||||||
//! }
|
|
||||||
//! fn evaluate(&mut self) -> i32 {
|
//! fn evaluate(&mut self) -> i32 {
|
||||||
//! *self.0.value() + *self.1.value()
|
//! *self.input_0() + *self.input_1()
|
||||||
//! }
|
//! }
|
||||||
//! }
|
//! }
|
||||||
//! ```
|
//! ```
|
||||||
@ -373,7 +367,7 @@ impl<V> Clone for ValueInvalidationSignal<V> {
|
|||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::rule::ConstantRule;
|
use crate::rule::{ConstantRule, InputVisitable};
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn rule_output_with_no_inputs() {
|
fn rule_output_with_no_inputs() {
|
||||||
@ -393,11 +387,13 @@ mod tests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
struct Double(Input<i32>);
|
struct Double(Input<i32>);
|
||||||
impl Rule for Double {
|
impl InputVisitable for Double {
|
||||||
type Output = i32;
|
|
||||||
fn visit_inputs(&self, visitor: &mut impl InputVisitor) {
|
fn visit_inputs(&self, visitor: &mut impl InputVisitor) {
|
||||||
visitor.visit(&self.0);
|
visitor.visit(&self.0);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
impl Rule for Double {
|
||||||
|
type Output = i32;
|
||||||
fn evaluate(&mut self) -> i32 {
|
fn evaluate(&mut self) -> i32 {
|
||||||
*self.0.value() * 2
|
*self.0.value() * 2
|
||||||
}
|
}
|
||||||
@ -421,9 +417,11 @@ mod tests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
struct Inc(i32);
|
struct Inc(i32);
|
||||||
|
impl InputVisitable for Inc {
|
||||||
|
fn visit_inputs(&self, _visitor: &mut impl InputVisitor) {}
|
||||||
|
}
|
||||||
impl Rule for Inc {
|
impl Rule for Inc {
|
||||||
type Output = i32;
|
type Output = i32;
|
||||||
fn visit_inputs(&self, _visitor: &mut impl InputVisitor) {}
|
|
||||||
fn evaluate(&mut self) -> i32 {
|
fn evaluate(&mut self) -> i32 {
|
||||||
self.0 += 1;
|
self.0 += 1;
|
||||||
return self.0;
|
return self.0;
|
||||||
@ -445,12 +443,14 @@ mod tests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
struct Add(Input<i32>, Input<i32>);
|
struct Add(Input<i32>, Input<i32>);
|
||||||
impl Rule for Add {
|
impl InputVisitable for Add {
|
||||||
type Output = i32;
|
|
||||||
fn visit_inputs(&self, visitor: &mut impl InputVisitor) {
|
fn visit_inputs(&self, visitor: &mut impl InputVisitor) {
|
||||||
visitor.visit(&self.0);
|
visitor.visit(&self.0);
|
||||||
visitor.visit(&self.1);
|
visitor.visit(&self.1);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
impl Rule for Add {
|
||||||
|
type Output = i32;
|
||||||
fn evaluate(&mut self) -> i32 {
|
fn evaluate(&mut self) -> i32 {
|
||||||
*self.0.value() + *self.1.value()
|
*self.0.value() + *self.1.value()
|
||||||
}
|
}
|
||||||
@ -489,13 +489,15 @@ mod tests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
struct DeferredInput(Rc<RefCell<Option<Input<i32>>>>);
|
struct DeferredInput(Rc<RefCell<Option<Input<i32>>>>);
|
||||||
impl Rule for DeferredInput {
|
impl InputVisitable for DeferredInput {
|
||||||
type Output = i32;
|
|
||||||
fn visit_inputs(&self, visitor: &mut impl InputVisitor) {
|
fn visit_inputs(&self, visitor: &mut impl InputVisitor) {
|
||||||
let borrowed = self.0.borrow();
|
let borrowed = self.0.borrow();
|
||||||
let input = borrowed.as_ref().unwrap();
|
let input = borrowed.as_ref().unwrap();
|
||||||
visitor.visit(input);
|
visitor.visit(input);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
impl Rule for DeferredInput {
|
||||||
|
type Output = i32;
|
||||||
fn evaluate(&mut self) -> i32 {
|
fn evaluate(&mut self) -> i32 {
|
||||||
*self.0.borrow().as_ref().unwrap().value()
|
*self.0.borrow().as_ref().unwrap().value()
|
||||||
}
|
}
|
||||||
@ -560,9 +562,11 @@ mod tests {
|
|||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn async_rule() {
|
async fn async_rule() {
|
||||||
struct AsyncConst(i32);
|
struct AsyncConst(i32);
|
||||||
|
impl InputVisitable for AsyncConst {
|
||||||
|
fn visit_inputs(&self, _visitor: &mut impl InputVisitor) {}
|
||||||
|
}
|
||||||
impl AsyncRule for AsyncConst {
|
impl AsyncRule for AsyncConst {
|
||||||
type Output = i32;
|
type Output = i32;
|
||||||
fn visit_inputs(&self, _visitor: &mut impl InputVisitor) {}
|
|
||||||
async fn evaluate(&mut self) -> i32 {
|
async fn evaluate(&mut self) -> i32 {
|
||||||
self.0
|
self.0
|
||||||
}
|
}
|
||||||
@ -578,9 +582,11 @@ mod tests {
|
|||||||
#[derive(PartialEq, Debug)]
|
#[derive(PartialEq, Debug)]
|
||||||
struct NonCloneable;
|
struct NonCloneable;
|
||||||
struct Output;
|
struct Output;
|
||||||
|
impl InputVisitable for Output {
|
||||||
|
fn visit_inputs(&self, _visitor: &mut impl InputVisitor) {}
|
||||||
|
}
|
||||||
impl Rule for Output {
|
impl Rule for Output {
|
||||||
type Output = NonCloneable;
|
type Output = NonCloneable;
|
||||||
fn visit_inputs(&self, _visitor: &mut impl InputVisitor) {}
|
|
||||||
fn evaluate(&mut self) -> Self::Output {
|
fn evaluate(&mut self) -> Self::Output {
|
||||||
NonCloneable
|
NonCloneable
|
||||||
}
|
}
|
||||||
@ -596,11 +602,13 @@ mod tests {
|
|||||||
let mut builder = GraphBuilder::new();
|
let mut builder = GraphBuilder::new();
|
||||||
let (a, invalidate) = builder.add_invalidatable_rule(ConstantRule::new(0));
|
let (a, invalidate) = builder.add_invalidatable_rule(ConstantRule::new(0));
|
||||||
struct IncAdd(Input<i32>, i32);
|
struct IncAdd(Input<i32>, i32);
|
||||||
impl Rule for IncAdd {
|
impl InputVisitable for IncAdd {
|
||||||
type Output = i32;
|
|
||||||
fn visit_inputs(&self, visitor: &mut impl InputVisitor) {
|
fn visit_inputs(&self, visitor: &mut impl InputVisitor) {
|
||||||
visitor.visit(&self.0);
|
visitor.visit(&self.0);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
impl Rule for IncAdd {
|
||||||
|
type Output = i32;
|
||||||
fn evaluate(&mut self) -> Self::Output {
|
fn evaluate(&mut self) -> Self::Output {
|
||||||
self.1 += 1;
|
self.1 += 1;
|
||||||
*self.0.value() + self.1
|
*self.0.value() + self.1
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
use crate::node::NodeValue;
|
use crate::node::NodeValue;
|
||||||
use crate::NodeId;
|
use crate::NodeId;
|
||||||
|
pub use compute_graph_macros::InputVisitable;
|
||||||
use std::cell::{Ref, RefCell};
|
use std::cell::{Ref, RefCell};
|
||||||
use std::ops::Deref;
|
use std::ops::Deref;
|
||||||
use std::rc::Rc;
|
use std::rc::Rc;
|
||||||
@ -9,38 +10,22 @@ use std::rc::Rc;
|
|||||||
/// A rule for addition could be implemented like so:
|
/// A rule for addition could be implemented like so:
|
||||||
///
|
///
|
||||||
/// ```rust
|
/// ```rust
|
||||||
/// # use compute_graph::rule::{Rule, Input, InputVisitor};
|
/// # use compute_graph::rule::{Rule, Input, InputVisitable};
|
||||||
|
/// #[derive(InputVisitable)]
|
||||||
/// struct Add(Input<i32>, Input<i32>);
|
/// struct Add(Input<i32>, Input<i32>);
|
||||||
///
|
///
|
||||||
/// impl Rule for Add {
|
/// impl Rule for Add {
|
||||||
/// type Output = i32;
|
/// type Output = i32;
|
||||||
///
|
///
|
||||||
/// fn visit_inputs(&self, visitor: &mut impl InputVisitor) {
|
|
||||||
/// visitor.visit(&self.0);
|
|
||||||
/// visitor.visit(&self.1);
|
|
||||||
/// }
|
|
||||||
///
|
|
||||||
/// fn evaluate(&mut self) -> Self::Output {
|
/// fn evaluate(&mut self) -> Self::Output {
|
||||||
/// *self.0.value() + *self.1.value()
|
/// *self.input_0() + *self.input_1()
|
||||||
/// }
|
/// }
|
||||||
/// }
|
/// }
|
||||||
/// ```
|
/// ```
|
||||||
pub trait Rule: 'static {
|
pub trait Rule: InputVisitable + 'static {
|
||||||
/// The type of the output value of the rule.
|
/// The type of the output value of the rule.
|
||||||
type Output: NodeValue;
|
type Output: NodeValue;
|
||||||
|
|
||||||
/// Visits all the [`Input`]s of this rule.
|
|
||||||
///
|
|
||||||
/// This method is called when the graph is built/modified in order to establish edges of the graph,
|
|
||||||
/// representing the dependencies. Any input that the [`InputVisitor::visit`] is called with is
|
|
||||||
/// considered a dependency of the rule's node.
|
|
||||||
///
|
|
||||||
/// While it is permitted for the dependencies of a rule to change after it has been added to the graph,
|
|
||||||
/// doing so only permitted before the graph has been built or during the callback of
|
|
||||||
/// [`Graph::modify`](`crate::Graph::modify`). Changes to the rule's dependencies outside of that will
|
|
||||||
/// not be detected and will not be represented in the graph.
|
|
||||||
fn visit_inputs(&self, visitor: &mut impl InputVisitor);
|
|
||||||
|
|
||||||
/// Produces the value of this rule using its inputs.
|
/// Produces the value of this rule using its inputs.
|
||||||
///
|
///
|
||||||
/// Note that the receiver of this method is a mutable reference to the rule itself. Rules are permitted
|
/// Note that the receiver of this method is a mutable reference to the rule itself. Rules are permitted
|
||||||
@ -57,37 +42,46 @@ pub trait Rule: 'static {
|
|||||||
/// A rule produces a value for a graph node asynchronously.
|
/// A rule produces a value for a graph node asynchronously.
|
||||||
///
|
///
|
||||||
/// ```rust
|
/// ```rust
|
||||||
/// # use compute_graph::rule::{AsyncRule, Input, InputVisitor};
|
/// # use compute_graph::rule::{AsyncRule, Input, InputVisitable};
|
||||||
/// # async fn do_async_work(_: i32) -> i32 { 0 }
|
/// # async fn do_async_work(_: i32) -> i32 { 0 }
|
||||||
|
/// #[derive(InputVisitable)]
|
||||||
/// struct AsyncMath(Input<i32>);
|
/// struct AsyncMath(Input<i32>);
|
||||||
///
|
///
|
||||||
/// impl AsyncRule for AsyncMath {
|
/// impl AsyncRule for AsyncMath {
|
||||||
/// type Output = i32;
|
/// type Output = i32;
|
||||||
///
|
///
|
||||||
/// fn visit_inputs(&self, visitor: &mut impl InputVisitor) {
|
|
||||||
/// visitor.visit(&self.0);
|
|
||||||
/// }
|
|
||||||
///
|
|
||||||
/// async fn evaluate(&mut self) -> Self::Output {
|
/// async fn evaluate(&mut self) -> Self::Output {
|
||||||
/// do_async_work(*self.0.value()).await
|
/// do_async_work(*self.input_0()).await
|
||||||
/// }
|
/// }
|
||||||
/// }
|
/// }
|
||||||
/// ```
|
/// ```
|
||||||
pub trait AsyncRule: 'static {
|
pub trait AsyncRule: InputVisitable + 'static {
|
||||||
/// The type of the output value of the rule.
|
/// The type of the output value of the rule.
|
||||||
type Output: NodeValue;
|
type Output: NodeValue;
|
||||||
|
|
||||||
/// Visits all the [`Input`]s of this rule.
|
|
||||||
///
|
|
||||||
/// See [`Rule::visit_inputs`] for additional details; the same caveats apply.
|
|
||||||
fn visit_inputs(&self, visitor: &mut impl InputVisitor);
|
|
||||||
|
|
||||||
/// Asynchronously produces the value of this rule using its inputs.
|
/// Asynchronously produces the value of this rule using its inputs.
|
||||||
///
|
///
|
||||||
/// See [`Rule::evaluate`] for additional details; the same considerations apply.
|
/// See [`Rule::evaluate`] for additional details; the same considerations apply.
|
||||||
async fn evaluate(&mut self) -> Self::Output;
|
async fn evaluate(&mut self) -> Self::Output;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Common supertrait of [`Rule`] and [`AsyncRule`] that defines how rule inputs are visited.
|
||||||
|
///
|
||||||
|
/// The implementation of this trait can generally be derived using [`derive@InputVisitable`].
|
||||||
|
pub trait InputVisitable {
|
||||||
|
/// Visits all the [`Input`]s of this rule.
|
||||||
|
///
|
||||||
|
/// This method is called when the graph is built/modified in order to establish edges of the graph,
|
||||||
|
/// representing the dependencies. Any input that the [`InputVisitor::visit`] is called with is
|
||||||
|
/// considered a dependency of the rule's node.
|
||||||
|
///
|
||||||
|
/// While it is permitted for the dependencies of a rule to change after it has been added to the graph,
|
||||||
|
/// doing so only permitted before the graph has been built or during the callback of
|
||||||
|
/// [`Graph::modify`](`crate::Graph::modify`). Changes to the rule's dependencies outside of that will
|
||||||
|
/// not be detected and will not be represented in the graph.
|
||||||
|
fn visit_inputs(&self, visitor: &mut impl InputVisitor);
|
||||||
|
}
|
||||||
|
|
||||||
/// A placeholder for the output of one node to be used as an input for another.
|
/// A placeholder for the output of one node to be used as an input for another.
|
||||||
///
|
///
|
||||||
/// To obtain an input, add a value or rule to a [`GraphBuilder`](`crate::builder::GraphBuilder`).
|
/// To obtain an input, add a value or rule to a [`GraphBuilder`](`crate::builder::GraphBuilder`).
|
||||||
@ -125,7 +119,8 @@ impl<T> Clone for Input<T> {
|
|||||||
|
|
||||||
/// A type that can visit arbitrary [`Input`]s.
|
/// A type that can visit arbitrary [`Input`]s.
|
||||||
///
|
///
|
||||||
/// You generally do not implement this trait yourself. An implementation is provided to [`Rule::visit_inputs`].
|
/// You generally do not implement this trait yourself. An implementation is provided to
|
||||||
|
/// [`InputVisitable::visit_inputs`].
|
||||||
pub trait InputVisitor {
|
pub trait InputVisitor {
|
||||||
/// Visit an input whose value is of type `T`.
|
/// Visit an input whose value is of type `T`.
|
||||||
fn visit<T>(&mut self, input: &Input<T>);
|
fn visit<T>(&mut self, input: &Input<T>);
|
||||||
@ -146,9 +141,11 @@ impl<T> ConstantRule<T> {
|
|||||||
impl<T: Clone + NodeValue> Rule for ConstantRule<T> {
|
impl<T: Clone + NodeValue> Rule for ConstantRule<T> {
|
||||||
type Output = T;
|
type Output = T;
|
||||||
|
|
||||||
fn visit_inputs(&self, _visitor: &mut impl InputVisitor) {}
|
|
||||||
|
|
||||||
fn evaluate(&mut self) -> Self::Output {
|
fn evaluate(&mut self) -> Self::Output {
|
||||||
self.0.clone()
|
self.0.clone()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl<T> InputVisitable for ConstantRule<T> {
|
||||||
|
fn visit_inputs(&self, _visitor: &mut impl InputVisitor) {}
|
||||||
|
}
|
||||||
|
12
crates/compute_graph_macros/Cargo.toml
Normal file
12
crates/compute_graph_macros/Cargo.toml
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
[package]
|
||||||
|
name = "compute_graph_macros"
|
||||||
|
version = "0.1.0"
|
||||||
|
edition = "2021"
|
||||||
|
|
||||||
|
[lib]
|
||||||
|
proc-macro = true
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
syn = "2"
|
||||||
|
quote = "1"
|
||||||
|
proc-macro2 = "1"
|
134
crates/compute_graph_macros/src/lib.rs
Normal file
134
crates/compute_graph_macros/src/lib.rs
Normal file
@ -0,0 +1,134 @@
|
|||||||
|
use proc_macro::TokenStream;
|
||||||
|
use proc_macro2::Literal;
|
||||||
|
use quote::{format_ident, quote};
|
||||||
|
use syn::{
|
||||||
|
parse_macro_input, Data, DataStruct, DeriveInput, Field, Fields, GenericArgument,
|
||||||
|
PathArguments, Type,
|
||||||
|
};
|
||||||
|
|
||||||
|
extern crate proc_macro;
|
||||||
|
|
||||||
|
/// Derive an implementation of the `InputVisitable` trait and helper methods.
|
||||||
|
///
|
||||||
|
/// This macro generates an implementation of the `InputVisitable` trait and the `visit_input` method that
|
||||||
|
/// calls `visit` on each field of the struct that is of type `Input<T>` for any T.
|
||||||
|
///
|
||||||
|
/// The macro also generates helper methods for accessing the value of each input less verbosely.
|
||||||
|
/// For unnamed struct fields, the methods generated have the form `input_0`, `input_1`, etc.
|
||||||
|
/// For named fields, the generated method name matches the field name. In both cases, the method
|
||||||
|
/// returns a reference to the input value. As with the `Input::value` method, calling the helper methods
|
||||||
|
/// before the referenced node has been evaluated is forbidden.
|
||||||
|
#[proc_macro_derive(InputVisitable)]
|
||||||
|
pub fn derive_rule(input: TokenStream) -> TokenStream {
|
||||||
|
let input = parse_macro_input!(input as DeriveInput);
|
||||||
|
if let Data::Struct(ref data) = input.data {
|
||||||
|
derive_rule_struct(&input, data)
|
||||||
|
} else {
|
||||||
|
TokenStream::from(
|
||||||
|
syn::Error::new(input.ident.span(), "Only structs can derive `Rule`")
|
||||||
|
.to_compile_error(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn derive_rule_struct(input: &DeriveInput, data: &DataStruct) -> TokenStream {
|
||||||
|
let name = &input.ident;
|
||||||
|
|
||||||
|
let visit_inputs = match data.fields {
|
||||||
|
Fields::Named(ref named) => named
|
||||||
|
.named
|
||||||
|
.iter()
|
||||||
|
.filter(|field| input_value_type(field).is_some())
|
||||||
|
.map(|field| {
|
||||||
|
let ident = field.ident.as_ref().unwrap();
|
||||||
|
quote!(visitor.visit(&self.#ident);)
|
||||||
|
})
|
||||||
|
.collect::<Vec<_>>(),
|
||||||
|
Fields::Unnamed(ref unnamed) => unnamed
|
||||||
|
.unnamed
|
||||||
|
.iter()
|
||||||
|
.enumerate()
|
||||||
|
.filter(|(_, field)| input_value_type(field).is_some())
|
||||||
|
.map(|(i, _)| {
|
||||||
|
let idx_lit = Literal::usize_unsuffixed(i);
|
||||||
|
quote!(visitor.visit(&self.#idx_lit);)
|
||||||
|
})
|
||||||
|
.collect::<Vec<_>>(),
|
||||||
|
Fields::Unit => vec![],
|
||||||
|
};
|
||||||
|
|
||||||
|
let input_value_methods = match data.fields {
|
||||||
|
Fields::Named(ref named) => named
|
||||||
|
.named
|
||||||
|
.iter()
|
||||||
|
.filter_map(|field| input_value_type(field).map(|ty| (field, ty)))
|
||||||
|
.map(|(field, ty)| {
|
||||||
|
let ident = field.ident.as_ref().unwrap();
|
||||||
|
quote!(
|
||||||
|
|
||||||
|
fn #ident(&self) -> impl ::std::ops::Deref<Target = #ty> + '_ {
|
||||||
|
self.#ident.value()
|
||||||
|
}
|
||||||
|
|
||||||
|
)
|
||||||
|
})
|
||||||
|
.collect::<Vec<_>>(),
|
||||||
|
Fields::Unnamed(ref unnamed) => unnamed
|
||||||
|
.unnamed
|
||||||
|
.iter()
|
||||||
|
.enumerate()
|
||||||
|
.filter_map(|(i, field)| input_value_type(field).map(|ty| (i, ty)))
|
||||||
|
.map(|(i, ty)| {
|
||||||
|
let idx_lit = Literal::usize_unsuffixed(i);
|
||||||
|
let ident = format_ident!("input_{i}");
|
||||||
|
quote!(
|
||||||
|
|
||||||
|
fn #ident(&self) -> impl ::std::ops::Deref<Target = #ty> + '_ {
|
||||||
|
self.#idx_lit.value()
|
||||||
|
}
|
||||||
|
|
||||||
|
)
|
||||||
|
})
|
||||||
|
.collect::<Vec<_>>(),
|
||||||
|
Fields::Unit => vec![],
|
||||||
|
};
|
||||||
|
|
||||||
|
TokenStream::from(quote!(
|
||||||
|
|
||||||
|
impl ::compute_graph::rule::InputVisitable for #name {
|
||||||
|
fn visit_inputs(&self, visitor: &mut impl ::compute_graph::rule::InputVisitor) {
|
||||||
|
#(#visit_inputs)*
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl #name {
|
||||||
|
#(#input_value_methods)*
|
||||||
|
}
|
||||||
|
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn input_value_type(field: &Field) -> Option<&Type> {
|
||||||
|
if let Type::Path(ref path) = field.ty {
|
||||||
|
let last_segment = path.path.segments.last().unwrap();
|
||||||
|
if last_segment.ident == "Input" {
|
||||||
|
if let PathArguments::AngleBracketed(ref args) = last_segment.arguments {
|
||||||
|
if args.args.len() == 1 {
|
||||||
|
if let GenericArgument::Type(ref ty) = args.args.first().unwrap() {
|
||||||
|
Some(ty)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
7
crates/derive_test/Cargo.toml
Normal file
7
crates/derive_test/Cargo.toml
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
[package]
|
||||||
|
name = "derive_test"
|
||||||
|
version = "0.1.0"
|
||||||
|
edition = "2021"
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
compute_graph = { path = "../compute_graph" }
|
52
crates/derive_test/src/lib.rs
Normal file
52
crates/derive_test/src/lib.rs
Normal file
@ -0,0 +1,52 @@
|
|||||||
|
use compute_graph::rule::{Input, InputVisitable, Rule};
|
||||||
|
|
||||||
|
#[derive(InputVisitable)]
|
||||||
|
struct Add(Input<i32>, Input<i32>, i32);
|
||||||
|
|
||||||
|
impl Rule for Add {
|
||||||
|
type Output = i32;
|
||||||
|
fn evaluate(&mut self) -> Self::Output {
|
||||||
|
*self.input_0() + *self.input_1() + self.2
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(InputVisitable)]
|
||||||
|
struct Add2 {
|
||||||
|
a: Input<i32>,
|
||||||
|
b: Input<i32>,
|
||||||
|
c: i32,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Rule for Add2 {
|
||||||
|
type Output = i32;
|
||||||
|
fn evaluate(&mut self) -> Self::Output {
|
||||||
|
*self.a() + *self.b() + self.c
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use compute_graph::builder::GraphBuilder;
|
||||||
|
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_add() {
|
||||||
|
let mut builder = GraphBuilder::new();
|
||||||
|
let a = builder.add_value(1);
|
||||||
|
let b = builder.add_value(2);
|
||||||
|
builder.set_output(Add(a, b, 3));
|
||||||
|
let mut graph = builder.build().unwrap();
|
||||||
|
assert_eq!(*graph.evaluate(), 6);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_add2() {
|
||||||
|
let mut builder = GraphBuilder::new();
|
||||||
|
let a = builder.add_value(1);
|
||||||
|
let b = builder.add_value(2);
|
||||||
|
builder.set_output(Add2 { a, b, c: 3 });
|
||||||
|
let mut graph = builder.build().unwrap();
|
||||||
|
assert_eq!(*graph.evaluate(), 6);
|
||||||
|
}
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user