Derive macro

This commit is contained in:
Shadowfacts 2024-11-03 01:22:19 -04:00
parent 36bcbe3c9c
commit 08a4bf87dc
10 changed files with 295 additions and 71 deletions

17
Cargo.lock generated
View File

@ -360,10 +360,20 @@ dependencies = [
name = "compute_graph"
version = "0.1.0"
dependencies = [
"compute_graph_macros",
"petgraph",
"tokio",
]
[[package]]
name = "compute_graph_macros"
version = "0.1.0"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.85",
]
[[package]]
name = "core-foundation"
version = "0.9.3"
@ -521,6 +531,13 @@ dependencies = [
"syn 1.0.99",
]
[[package]]
name = "derive_test"
version = "0.1.0"
dependencies = [
"compute_graph",
]
[[package]]
name = "digest"
version = "0.10.3"

View File

@ -1,4 +1,4 @@
workspace = { members = ["crates/compute_graph"] }
workspace = { members = ["crates/compute_graph", "crates/compute_graph_macros", "crates/derive_test"] }
[package]
name = "v6"

View File

@ -6,6 +6,7 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
compute_graph_macros = { path = "../compute_graph_macros" }
petgraph = "0.6.5"
[dev-dependencies]

View File

@ -92,17 +92,15 @@ impl<O: 'static, S: Synchronicity> GraphBuilder<O, S> {
/// the value of the node can be replaced, invalidating the node in the process.
///
/// ```rust
/// # use compute_graph::{builder::GraphBuilder, rule::{Rule, Input, InputVisitor}};
/// # use compute_graph::{builder::GraphBuilder, rule::{Rule, Input, InputVisitable}};
/// let mut builder = GraphBuilder::new();
/// let (input, signal) = builder.add_invalidatable_value(0);
/// # #[derive(InputVisitable)]
/// # struct Double(Input<i32>);
/// # impl Rule for Double {
/// # type Output = i32;
/// # fn visit_inputs(&self, visitor: &mut impl InputVisitor) {
/// # visitor.visit(&self.0);
/// # }
/// # fn evaluate(&mut self) -> i32 {
/// # *self.0.value() * 2
/// # *self.input_0() * 2
/// # }
/// # }
/// builder.set_output(Double(input));
@ -139,26 +137,24 @@ impl<O: 'static, S: Synchronicity> GraphBuilder<O, S> {
/// as well as an [`InvalidationSignal`] which can be used to indicate that the node has been invalidated.
///
/// ```rust
/// # use compute_graph::{builder::GraphBuilder, rule::{Rule, Input, InputVisitor}};
/// # use compute_graph::{builder::GraphBuilder, rule::{Rule, Input, InputVisitable}};
/// let mut builder = GraphBuilder::new();
/// # #[derive(InputVisitable)]
/// # struct IncrementAfterEvaluate(i32);
/// # impl Rule for IncrementAfterEvaluate {
/// # type Output = i32;
/// # fn visit_inputs(&self, visitor: &mut impl InputVisitor) {}
/// # fn evaluate(&mut self) -> i32 {
/// # let result = self.0;
/// # self.0 += 1;
/// # result
/// # }
/// # }
/// # #[derive(InputVisitable)]
/// # struct Double(Input<i32>);
/// # impl Rule for Double {
/// # type Output = i32;
/// # fn visit_inputs(&self, visitor: &mut impl InputVisitor) {
/// # visitor.visit(&self.0);
/// # }
/// # fn evaluate(&mut self) -> i32 {
/// # *self.0.value() * 2
/// # *self.input_0() * 2
/// # }
/// # }
/// let (input, signal) = builder.add_invalidatable_rule(IncrementAfterEvaluate(1));

View File

@ -7,19 +7,16 @@
//! dependencies. For example, an arithmetic operation can be implemented like so:
//!
//! ```rust
//! # use compute_graph::{builder::GraphBuilder, rule::{Rule, Input, InputVisitor}};
//! # use compute_graph::{builder::GraphBuilder, rule::{Rule, Input, InputVisitable}};
//! let mut builder = GraphBuilder::new();
//! let a = builder.add_value(1);
//! let b = builder.add_value(2);
//! # #[derive(InputVisitable)]
//! # struct Add(Input<i32>, Input<i32>);
//! # impl Rule for Add {
//! # type Output = i32;
//! # fn visit_inputs(&self, visitor: &mut impl InputVisitor) {
//! # visitor.visit(&self.0);
//! # visitor.visit(&self.1);
//! # }
//! # fn evaluate(&mut self) -> i32 {
//! # *self.0.value() + *self.1.value()
//! # *self.input_0() + *self.input_1()
//! # }
//! # }
//! builder.set_output(Add(a, b));
@ -33,17 +30,14 @@
//! The `Add` rule is implemented as follows:
//!
//! ```rust
//! # use compute_graph::{builder::GraphBuilder, rule::{Rule, Input, InputVisitor}};
//! # use compute_graph::{builder::GraphBuilder, rule::{Rule, Input, InputVisitable}};
//! #[derive(InputVisitable)]
//! struct Add(Input<i32>, Input<i32>);
//!
//! impl Rule for Add {
//! type Output = i32;
//! fn visit_inputs(&self, visitor: &mut impl InputVisitor) {
//! visitor.visit(&self.0);
//! visitor.visit(&self.1);
//! }
//! fn evaluate(&mut self) -> i32 {
//! *self.0.value() + *self.1.value()
//! *self.input_0() + *self.input_1()
//! }
//! }
//! ```
@ -373,7 +367,7 @@ impl<V> Clone for ValueInvalidationSignal<V> {
#[cfg(test)]
mod tests {
use super::*;
use crate::rule::ConstantRule;
use crate::rule::{ConstantRule, InputVisitable};
#[test]
fn rule_output_with_no_inputs() {
@ -393,11 +387,13 @@ mod tests {
}
struct Double(Input<i32>);
impl Rule for Double {
type Output = i32;
impl InputVisitable for Double {
fn visit_inputs(&self, visitor: &mut impl InputVisitor) {
visitor.visit(&self.0);
}
}
impl Rule for Double {
type Output = i32;
fn evaluate(&mut self) -> i32 {
*self.0.value() * 2
}
@ -421,9 +417,11 @@ mod tests {
}
struct Inc(i32);
impl InputVisitable for Inc {
fn visit_inputs(&self, _visitor: &mut impl InputVisitor) {}
}
impl Rule for Inc {
type Output = i32;
fn visit_inputs(&self, _visitor: &mut impl InputVisitor) {}
fn evaluate(&mut self) -> i32 {
self.0 += 1;
return self.0;
@ -445,12 +443,14 @@ mod tests {
}
struct Add(Input<i32>, Input<i32>);
impl Rule for Add {
type Output = i32;
impl InputVisitable for Add {
fn visit_inputs(&self, visitor: &mut impl InputVisitor) {
visitor.visit(&self.0);
visitor.visit(&self.1);
}
}
impl Rule for Add {
type Output = i32;
fn evaluate(&mut self) -> i32 {
*self.0.value() + *self.1.value()
}
@ -489,13 +489,15 @@ mod tests {
}
struct DeferredInput(Rc<RefCell<Option<Input<i32>>>>);
impl Rule for DeferredInput {
type Output = i32;
impl InputVisitable for DeferredInput {
fn visit_inputs(&self, visitor: &mut impl InputVisitor) {
let borrowed = self.0.borrow();
let input = borrowed.as_ref().unwrap();
visitor.visit(input);
}
}
impl Rule for DeferredInput {
type Output = i32;
fn evaluate(&mut self) -> i32 {
*self.0.borrow().as_ref().unwrap().value()
}
@ -560,9 +562,11 @@ mod tests {
#[tokio::test]
async fn async_rule() {
struct AsyncConst(i32);
impl InputVisitable for AsyncConst {
fn visit_inputs(&self, _visitor: &mut impl InputVisitor) {}
}
impl AsyncRule for AsyncConst {
type Output = i32;
fn visit_inputs(&self, _visitor: &mut impl InputVisitor) {}
async fn evaluate(&mut self) -> i32 {
self.0
}
@ -578,9 +582,11 @@ mod tests {
#[derive(PartialEq, Debug)]
struct NonCloneable;
struct Output;
impl InputVisitable for Output {
fn visit_inputs(&self, _visitor: &mut impl InputVisitor) {}
}
impl Rule for Output {
type Output = NonCloneable;
fn visit_inputs(&self, _visitor: &mut impl InputVisitor) {}
fn evaluate(&mut self) -> Self::Output {
NonCloneable
}
@ -596,11 +602,13 @@ mod tests {
let mut builder = GraphBuilder::new();
let (a, invalidate) = builder.add_invalidatable_rule(ConstantRule::new(0));
struct IncAdd(Input<i32>, i32);
impl Rule for IncAdd {
type Output = i32;
impl InputVisitable for IncAdd {
fn visit_inputs(&self, visitor: &mut impl InputVisitor) {
visitor.visit(&self.0);
}
}
impl Rule for IncAdd {
type Output = i32;
fn evaluate(&mut self) -> Self::Output {
self.1 += 1;
*self.0.value() + self.1

View File

@ -1,5 +1,6 @@
use crate::node::NodeValue;
use crate::NodeId;
pub use compute_graph_macros::InputVisitable;
use std::cell::{Ref, RefCell};
use std::ops::Deref;
use std::rc::Rc;
@ -9,38 +10,22 @@ use std::rc::Rc;
/// A rule for addition could be implemented like so:
///
/// ```rust
/// # use compute_graph::rule::{Rule, Input, InputVisitor};
/// # use compute_graph::rule::{Rule, Input, InputVisitable};
/// #[derive(InputVisitable)]
/// struct Add(Input<i32>, Input<i32>);
///
/// impl Rule for Add {
/// type Output = i32;
///
/// fn visit_inputs(&self, visitor: &mut impl InputVisitor) {
/// visitor.visit(&self.0);
/// visitor.visit(&self.1);
/// }
///
/// fn evaluate(&mut self) -> Self::Output {
/// *self.0.value() + *self.1.value()
/// *self.input_0() + *self.input_1()
/// }
/// }
/// ```
pub trait Rule: 'static {
pub trait Rule: InputVisitable + 'static {
/// The type of the output value of the rule.
type Output: NodeValue;
/// Visits all the [`Input`]s of this rule.
///
/// This method is called when the graph is built/modified in order to establish edges of the graph,
/// representing the dependencies. Any input that the [`InputVisitor::visit`] is called with is
/// considered a dependency of the rule's node.
///
/// While it is permitted for the dependencies of a rule to change after it has been added to the graph,
/// doing so only permitted before the graph has been built or during the callback of
/// [`Graph::modify`](`crate::Graph::modify`). Changes to the rule's dependencies outside of that will
/// not be detected and will not be represented in the graph.
fn visit_inputs(&self, visitor: &mut impl InputVisitor);
/// Produces the value of this rule using its inputs.
///
/// Note that the receiver of this method is a mutable reference to the rule itself. Rules are permitted
@ -57,37 +42,46 @@ pub trait Rule: 'static {
/// A rule produces a value for a graph node asynchronously.
///
/// ```rust
/// # use compute_graph::rule::{AsyncRule, Input, InputVisitor};
/// # use compute_graph::rule::{AsyncRule, Input, InputVisitable};
/// # async fn do_async_work(_: i32) -> i32 { 0 }
/// #[derive(InputVisitable)]
/// struct AsyncMath(Input<i32>);
///
/// impl AsyncRule for AsyncMath {
/// type Output = i32;
///
/// fn visit_inputs(&self, visitor: &mut impl InputVisitor) {
/// visitor.visit(&self.0);
/// }
///
/// async fn evaluate(&mut self) -> Self::Output {
/// do_async_work(*self.0.value()).await
/// do_async_work(*self.input_0()).await
/// }
/// }
/// ```
pub trait AsyncRule: 'static {
pub trait AsyncRule: InputVisitable + 'static {
/// The type of the output value of the rule.
type Output: NodeValue;
/// Visits all the [`Input`]s of this rule.
///
/// See [`Rule::visit_inputs`] for additional details; the same caveats apply.
fn visit_inputs(&self, visitor: &mut impl InputVisitor);
/// Asynchronously produces the value of this rule using its inputs.
///
/// See [`Rule::evaluate`] for additional details; the same considerations apply.
async fn evaluate(&mut self) -> Self::Output;
}
/// Common supertrait of [`Rule`] and [`AsyncRule`] that defines how rule inputs are visited.
///
/// The implementation of this trait can generally be derived using [`derive@InputVisitable`].
pub trait InputVisitable {
/// Visits all the [`Input`]s of this rule.
///
/// This method is called when the graph is built/modified in order to establish edges of the graph,
/// representing the dependencies. Any input that the [`InputVisitor::visit`] is called with is
/// considered a dependency of the rule's node.
///
/// While it is permitted for the dependencies of a rule to change after it has been added to the graph,
/// doing so only permitted before the graph has been built or during the callback of
/// [`Graph::modify`](`crate::Graph::modify`). Changes to the rule's dependencies outside of that will
/// not be detected and will not be represented in the graph.
fn visit_inputs(&self, visitor: &mut impl InputVisitor);
}
/// A placeholder for the output of one node to be used as an input for another.
///
/// To obtain an input, add a value or rule to a [`GraphBuilder`](`crate::builder::GraphBuilder`).
@ -125,7 +119,8 @@ impl<T> Clone for Input<T> {
/// A type that can visit arbitrary [`Input`]s.
///
/// You generally do not implement this trait yourself. An implementation is provided to [`Rule::visit_inputs`].
/// You generally do not implement this trait yourself. An implementation is provided to
/// [`InputVisitable::visit_inputs`].
pub trait InputVisitor {
/// Visit an input whose value is of type `T`.
fn visit<T>(&mut self, input: &Input<T>);
@ -146,9 +141,11 @@ impl<T> ConstantRule<T> {
impl<T: Clone + NodeValue> Rule for ConstantRule<T> {
type Output = T;
fn visit_inputs(&self, _visitor: &mut impl InputVisitor) {}
fn evaluate(&mut self) -> Self::Output {
self.0.clone()
}
}
impl<T> InputVisitable for ConstantRule<T> {
fn visit_inputs(&self, _visitor: &mut impl InputVisitor) {}
}

View File

@ -0,0 +1,12 @@
[package]
name = "compute_graph_macros"
version = "0.1.0"
edition = "2021"
[lib]
proc-macro = true
[dependencies]
syn = "2"
quote = "1"
proc-macro2 = "1"

View File

@ -0,0 +1,134 @@
use proc_macro::TokenStream;
use proc_macro2::Literal;
use quote::{format_ident, quote};
use syn::{
parse_macro_input, Data, DataStruct, DeriveInput, Field, Fields, GenericArgument,
PathArguments, Type,
};
extern crate proc_macro;
/// Derive an implementation of the `InputVisitable` trait and helper methods.
///
/// This macro generates an implementation of the `InputVisitable` trait and the `visit_input` method that
/// calls `visit` on each field of the struct that is of type `Input<T>` for any T.
///
/// The macro also generates helper methods for accessing the value of each input less verbosely.
/// For unnamed struct fields, the methods generated have the form `input_0`, `input_1`, etc.
/// For named fields, the generated method name matches the field name. In both cases, the method
/// returns a reference to the input value. As with the `Input::value` method, calling the helper methods
/// before the referenced node has been evaluated is forbidden.
#[proc_macro_derive(InputVisitable)]
pub fn derive_rule(input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as DeriveInput);
if let Data::Struct(ref data) = input.data {
derive_rule_struct(&input, data)
} else {
TokenStream::from(
syn::Error::new(input.ident.span(), "Only structs can derive `Rule`")
.to_compile_error(),
)
}
}
fn derive_rule_struct(input: &DeriveInput, data: &DataStruct) -> TokenStream {
let name = &input.ident;
let visit_inputs = match data.fields {
Fields::Named(ref named) => named
.named
.iter()
.filter(|field| input_value_type(field).is_some())
.map(|field| {
let ident = field.ident.as_ref().unwrap();
quote!(visitor.visit(&self.#ident);)
})
.collect::<Vec<_>>(),
Fields::Unnamed(ref unnamed) => unnamed
.unnamed
.iter()
.enumerate()
.filter(|(_, field)| input_value_type(field).is_some())
.map(|(i, _)| {
let idx_lit = Literal::usize_unsuffixed(i);
quote!(visitor.visit(&self.#idx_lit);)
})
.collect::<Vec<_>>(),
Fields::Unit => vec![],
};
let input_value_methods = match data.fields {
Fields::Named(ref named) => named
.named
.iter()
.filter_map(|field| input_value_type(field).map(|ty| (field, ty)))
.map(|(field, ty)| {
let ident = field.ident.as_ref().unwrap();
quote!(
fn #ident(&self) -> impl ::std::ops::Deref<Target = #ty> + '_ {
self.#ident.value()
}
)
})
.collect::<Vec<_>>(),
Fields::Unnamed(ref unnamed) => unnamed
.unnamed
.iter()
.enumerate()
.filter_map(|(i, field)| input_value_type(field).map(|ty| (i, ty)))
.map(|(i, ty)| {
let idx_lit = Literal::usize_unsuffixed(i);
let ident = format_ident!("input_{i}");
quote!(
fn #ident(&self) -> impl ::std::ops::Deref<Target = #ty> + '_ {
self.#idx_lit.value()
}
)
})
.collect::<Vec<_>>(),
Fields::Unit => vec![],
};
TokenStream::from(quote!(
impl ::compute_graph::rule::InputVisitable for #name {
fn visit_inputs(&self, visitor: &mut impl ::compute_graph::rule::InputVisitor) {
#(#visit_inputs)*
}
}
impl #name {
#(#input_value_methods)*
}
))
}
fn input_value_type(field: &Field) -> Option<&Type> {
if let Type::Path(ref path) = field.ty {
let last_segment = path.path.segments.last().unwrap();
if last_segment.ident == "Input" {
if let PathArguments::AngleBracketed(ref args) = last_segment.arguments {
if args.args.len() == 1 {
if let GenericArgument::Type(ref ty) = args.args.first().unwrap() {
Some(ty)
} else {
None
}
} else {
None
}
} else {
None
}
} else {
None
}
} else {
None
}
}

View File

@ -0,0 +1,7 @@
[package]
name = "derive_test"
version = "0.1.0"
edition = "2021"
[dependencies]
compute_graph = { path = "../compute_graph" }

View File

@ -0,0 +1,52 @@
use compute_graph::rule::{Input, InputVisitable, Rule};
#[derive(InputVisitable)]
struct Add(Input<i32>, Input<i32>, i32);
impl Rule for Add {
type Output = i32;
fn evaluate(&mut self) -> Self::Output {
*self.input_0() + *self.input_1() + self.2
}
}
#[derive(InputVisitable)]
struct Add2 {
a: Input<i32>,
b: Input<i32>,
c: i32,
}
impl Rule for Add2 {
type Output = i32;
fn evaluate(&mut self) -> Self::Output {
*self.a() + *self.b() + self.c
}
}
#[cfg(test)]
mod tests {
use compute_graph::builder::GraphBuilder;
use super::*;
#[test]
fn test_add() {
let mut builder = GraphBuilder::new();
let a = builder.add_value(1);
let b = builder.add_value(2);
builder.set_output(Add(a, b, 3));
let mut graph = builder.build().unwrap();
assert_eq!(*graph.evaluate(), 6);
}
#[test]
fn test_add2() {
let mut builder = GraphBuilder::new();
let a = builder.add_value(1);
let b = builder.add_value(2);
builder.set_output(Add2 { a, b, c: 3 });
let mut graph = builder.build().unwrap();
assert_eq!(*graph.evaluate(), 6);
}
}