From 08a4bf87dc85c3e701e188100abcb39a94c4ab3a Mon Sep 17 00:00:00 2001 From: Shadowfacts Date: Sun, 3 Nov 2024 01:22:19 -0400 Subject: [PATCH] Derive macro --- Cargo.lock | 17 ++++ Cargo.toml | 2 +- crates/compute_graph/Cargo.toml | 1 + crates/compute_graph/src/builder.rs | 18 ++-- crates/compute_graph/src/lib.rs | 56 ++++++----- crates/compute_graph/src/rule.rs | 67 ++++++------- crates/compute_graph_macros/Cargo.toml | 12 +++ crates/compute_graph_macros/src/lib.rs | 134 +++++++++++++++++++++++++ crates/derive_test/Cargo.toml | 7 ++ crates/derive_test/src/lib.rs | 52 ++++++++++ 10 files changed, 295 insertions(+), 71 deletions(-) create mode 100644 crates/compute_graph_macros/Cargo.toml create mode 100644 crates/compute_graph_macros/src/lib.rs create mode 100644 crates/derive_test/Cargo.toml create mode 100644 crates/derive_test/src/lib.rs diff --git a/Cargo.lock b/Cargo.lock index 6a2152e..c2a5d8a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -360,10 +360,20 @@ dependencies = [ name = "compute_graph" version = "0.1.0" dependencies = [ + "compute_graph_macros", "petgraph", "tokio", ] +[[package]] +name = "compute_graph_macros" +version = "0.1.0" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.85", +] + [[package]] name = "core-foundation" version = "0.9.3" @@ -521,6 +531,13 @@ dependencies = [ "syn 1.0.99", ] +[[package]] +name = "derive_test" +version = "0.1.0" +dependencies = [ + "compute_graph", +] + [[package]] name = "digest" version = "0.10.3" diff --git a/Cargo.toml b/Cargo.toml index 8a63e0d..ceaaea9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,4 +1,4 @@ -workspace = { members = ["crates/compute_graph"] } +workspace = { members = ["crates/compute_graph", "crates/compute_graph_macros", "crates/derive_test"] } [package] name = "v6" diff --git a/crates/compute_graph/Cargo.toml b/crates/compute_graph/Cargo.toml index dd289ca..f6fd4fb 100644 --- a/crates/compute_graph/Cargo.toml +++ b/crates/compute_graph/Cargo.toml @@ -6,6 +6,7 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] +compute_graph_macros = { path = "../compute_graph_macros" } petgraph = "0.6.5" [dev-dependencies] diff --git a/crates/compute_graph/src/builder.rs b/crates/compute_graph/src/builder.rs index 0273186..92253ec 100644 --- a/crates/compute_graph/src/builder.rs +++ b/crates/compute_graph/src/builder.rs @@ -92,17 +92,15 @@ impl GraphBuilder { /// the value of the node can be replaced, invalidating the node in the process. /// /// ```rust - /// # use compute_graph::{builder::GraphBuilder, rule::{Rule, Input, InputVisitor}}; + /// # use compute_graph::{builder::GraphBuilder, rule::{Rule, Input, InputVisitable}}; /// let mut builder = GraphBuilder::new(); /// let (input, signal) = builder.add_invalidatable_value(0); + /// # #[derive(InputVisitable)] /// # struct Double(Input); /// # impl Rule for Double { /// # type Output = i32; - /// # fn visit_inputs(&self, visitor: &mut impl InputVisitor) { - /// # visitor.visit(&self.0); - /// # } /// # fn evaluate(&mut self) -> i32 { - /// # *self.0.value() * 2 + /// # *self.input_0() * 2 /// # } /// # } /// builder.set_output(Double(input)); @@ -139,26 +137,24 @@ impl GraphBuilder { /// as well as an [`InvalidationSignal`] which can be used to indicate that the node has been invalidated. /// /// ```rust - /// # use compute_graph::{builder::GraphBuilder, rule::{Rule, Input, InputVisitor}}; + /// # use compute_graph::{builder::GraphBuilder, rule::{Rule, Input, InputVisitable}}; /// let mut builder = GraphBuilder::new(); + /// # #[derive(InputVisitable)] /// # struct IncrementAfterEvaluate(i32); /// # impl Rule for IncrementAfterEvaluate { /// # type Output = i32; - /// # fn visit_inputs(&self, visitor: &mut impl InputVisitor) {} /// # fn evaluate(&mut self) -> i32 { /// # let result = self.0; /// # self.0 += 1; /// # result /// # } /// # } + /// # #[derive(InputVisitable)] /// # struct Double(Input); /// # impl Rule for Double { /// # type Output = i32; - /// # fn visit_inputs(&self, visitor: &mut impl InputVisitor) { - /// # visitor.visit(&self.0); - /// # } /// # fn evaluate(&mut self) -> i32 { - /// # *self.0.value() * 2 + /// # *self.input_0() * 2 /// # } /// # } /// let (input, signal) = builder.add_invalidatable_rule(IncrementAfterEvaluate(1)); diff --git a/crates/compute_graph/src/lib.rs b/crates/compute_graph/src/lib.rs index 3e0d946..5933583 100644 --- a/crates/compute_graph/src/lib.rs +++ b/crates/compute_graph/src/lib.rs @@ -7,19 +7,16 @@ //! dependencies. For example, an arithmetic operation can be implemented like so: //! //! ```rust -//! # use compute_graph::{builder::GraphBuilder, rule::{Rule, Input, InputVisitor}}; +//! # use compute_graph::{builder::GraphBuilder, rule::{Rule, Input, InputVisitable}}; //! let mut builder = GraphBuilder::new(); //! let a = builder.add_value(1); //! let b = builder.add_value(2); +//! # #[derive(InputVisitable)] //! # struct Add(Input, Input); //! # impl Rule for Add { //! # type Output = i32; -//! # fn visit_inputs(&self, visitor: &mut impl InputVisitor) { -//! # visitor.visit(&self.0); -//! # visitor.visit(&self.1); -//! # } //! # fn evaluate(&mut self) -> i32 { -//! # *self.0.value() + *self.1.value() +//! # *self.input_0() + *self.input_1() //! # } //! # } //! builder.set_output(Add(a, b)); @@ -33,17 +30,14 @@ //! The `Add` rule is implemented as follows: //! //! ```rust -//! # use compute_graph::{builder::GraphBuilder, rule::{Rule, Input, InputVisitor}}; +//! # use compute_graph::{builder::GraphBuilder, rule::{Rule, Input, InputVisitable}}; +//! #[derive(InputVisitable)] //! struct Add(Input, Input); //! //! impl Rule for Add { //! type Output = i32; -//! fn visit_inputs(&self, visitor: &mut impl InputVisitor) { -//! visitor.visit(&self.0); -//! visitor.visit(&self.1); -//! } //! fn evaluate(&mut self) -> i32 { -//! *self.0.value() + *self.1.value() +//! *self.input_0() + *self.input_1() //! } //! } //! ``` @@ -373,7 +367,7 @@ impl Clone for ValueInvalidationSignal { #[cfg(test)] mod tests { use super::*; - use crate::rule::ConstantRule; + use crate::rule::{ConstantRule, InputVisitable}; #[test] fn rule_output_with_no_inputs() { @@ -393,11 +387,13 @@ mod tests { } struct Double(Input); - impl Rule for Double { - type Output = i32; + impl InputVisitable for Double { fn visit_inputs(&self, visitor: &mut impl InputVisitor) { visitor.visit(&self.0); } + } + impl Rule for Double { + type Output = i32; fn evaluate(&mut self) -> i32 { *self.0.value() * 2 } @@ -421,9 +417,11 @@ mod tests { } struct Inc(i32); + impl InputVisitable for Inc { + fn visit_inputs(&self, _visitor: &mut impl InputVisitor) {} + } impl Rule for Inc { type Output = i32; - fn visit_inputs(&self, _visitor: &mut impl InputVisitor) {} fn evaluate(&mut self) -> i32 { self.0 += 1; return self.0; @@ -445,12 +443,14 @@ mod tests { } struct Add(Input, Input); - impl Rule for Add { - type Output = i32; + impl InputVisitable for Add { fn visit_inputs(&self, visitor: &mut impl InputVisitor) { visitor.visit(&self.0); visitor.visit(&self.1); } + } + impl Rule for Add { + type Output = i32; fn evaluate(&mut self) -> i32 { *self.0.value() + *self.1.value() } @@ -489,13 +489,15 @@ mod tests { } struct DeferredInput(Rc>>>); - impl Rule for DeferredInput { - type Output = i32; + impl InputVisitable for DeferredInput { fn visit_inputs(&self, visitor: &mut impl InputVisitor) { let borrowed = self.0.borrow(); let input = borrowed.as_ref().unwrap(); visitor.visit(input); } + } + impl Rule for DeferredInput { + type Output = i32; fn evaluate(&mut self) -> i32 { *self.0.borrow().as_ref().unwrap().value() } @@ -560,9 +562,11 @@ mod tests { #[tokio::test] async fn async_rule() { struct AsyncConst(i32); + impl InputVisitable for AsyncConst { + fn visit_inputs(&self, _visitor: &mut impl InputVisitor) {} + } impl AsyncRule for AsyncConst { type Output = i32; - fn visit_inputs(&self, _visitor: &mut impl InputVisitor) {} async fn evaluate(&mut self) -> i32 { self.0 } @@ -578,9 +582,11 @@ mod tests { #[derive(PartialEq, Debug)] struct NonCloneable; struct Output; + impl InputVisitable for Output { + fn visit_inputs(&self, _visitor: &mut impl InputVisitor) {} + } impl Rule for Output { type Output = NonCloneable; - fn visit_inputs(&self, _visitor: &mut impl InputVisitor) {} fn evaluate(&mut self) -> Self::Output { NonCloneable } @@ -596,11 +602,13 @@ mod tests { let mut builder = GraphBuilder::new(); let (a, invalidate) = builder.add_invalidatable_rule(ConstantRule::new(0)); struct IncAdd(Input, i32); - impl Rule for IncAdd { - type Output = i32; + impl InputVisitable for IncAdd { fn visit_inputs(&self, visitor: &mut impl InputVisitor) { visitor.visit(&self.0); } + } + impl Rule for IncAdd { + type Output = i32; fn evaluate(&mut self) -> Self::Output { self.1 += 1; *self.0.value() + self.1 diff --git a/crates/compute_graph/src/rule.rs b/crates/compute_graph/src/rule.rs index 73fe8a5..45377b1 100644 --- a/crates/compute_graph/src/rule.rs +++ b/crates/compute_graph/src/rule.rs @@ -1,5 +1,6 @@ use crate::node::NodeValue; use crate::NodeId; +pub use compute_graph_macros::InputVisitable; use std::cell::{Ref, RefCell}; use std::ops::Deref; use std::rc::Rc; @@ -9,38 +10,22 @@ use std::rc::Rc; /// A rule for addition could be implemented like so: /// /// ```rust -/// # use compute_graph::rule::{Rule, Input, InputVisitor}; +/// # use compute_graph::rule::{Rule, Input, InputVisitable}; +/// #[derive(InputVisitable)] /// struct Add(Input, Input); /// /// impl Rule for Add { /// type Output = i32; /// -/// fn visit_inputs(&self, visitor: &mut impl InputVisitor) { -/// visitor.visit(&self.0); -/// visitor.visit(&self.1); -/// } -/// /// fn evaluate(&mut self) -> Self::Output { -/// *self.0.value() + *self.1.value() +/// *self.input_0() + *self.input_1() /// } /// } /// ``` -pub trait Rule: 'static { +pub trait Rule: InputVisitable + 'static { /// The type of the output value of the rule. type Output: NodeValue; - /// Visits all the [`Input`]s of this rule. - /// - /// This method is called when the graph is built/modified in order to establish edges of the graph, - /// representing the dependencies. Any input that the [`InputVisitor::visit`] is called with is - /// considered a dependency of the rule's node. - /// - /// While it is permitted for the dependencies of a rule to change after it has been added to the graph, - /// doing so only permitted before the graph has been built or during the callback of - /// [`Graph::modify`](`crate::Graph::modify`). Changes to the rule's dependencies outside of that will - /// not be detected and will not be represented in the graph. - fn visit_inputs(&self, visitor: &mut impl InputVisitor); - /// Produces the value of this rule using its inputs. /// /// Note that the receiver of this method is a mutable reference to the rule itself. Rules are permitted @@ -57,37 +42,46 @@ pub trait Rule: 'static { /// A rule produces a value for a graph node asynchronously. /// /// ```rust -/// # use compute_graph::rule::{AsyncRule, Input, InputVisitor}; +/// # use compute_graph::rule::{AsyncRule, Input, InputVisitable}; /// # async fn do_async_work(_: i32) -> i32 { 0 } +/// #[derive(InputVisitable)] /// struct AsyncMath(Input); /// /// impl AsyncRule for AsyncMath { /// type Output = i32; /// -/// fn visit_inputs(&self, visitor: &mut impl InputVisitor) { -/// visitor.visit(&self.0); -/// } -/// /// async fn evaluate(&mut self) -> Self::Output { -/// do_async_work(*self.0.value()).await +/// do_async_work(*self.input_0()).await /// } /// } /// ``` -pub trait AsyncRule: 'static { +pub trait AsyncRule: InputVisitable + 'static { /// The type of the output value of the rule. type Output: NodeValue; - /// Visits all the [`Input`]s of this rule. - /// - /// See [`Rule::visit_inputs`] for additional details; the same caveats apply. - fn visit_inputs(&self, visitor: &mut impl InputVisitor); - /// Asynchronously produces the value of this rule using its inputs. /// /// See [`Rule::evaluate`] for additional details; the same considerations apply. async fn evaluate(&mut self) -> Self::Output; } +/// Common supertrait of [`Rule`] and [`AsyncRule`] that defines how rule inputs are visited. +/// +/// The implementation of this trait can generally be derived using [`derive@InputVisitable`]. +pub trait InputVisitable { + /// Visits all the [`Input`]s of this rule. + /// + /// This method is called when the graph is built/modified in order to establish edges of the graph, + /// representing the dependencies. Any input that the [`InputVisitor::visit`] is called with is + /// considered a dependency of the rule's node. + /// + /// While it is permitted for the dependencies of a rule to change after it has been added to the graph, + /// doing so only permitted before the graph has been built or during the callback of + /// [`Graph::modify`](`crate::Graph::modify`). Changes to the rule's dependencies outside of that will + /// not be detected and will not be represented in the graph. + fn visit_inputs(&self, visitor: &mut impl InputVisitor); +} + /// A placeholder for the output of one node to be used as an input for another. /// /// To obtain an input, add a value or rule to a [`GraphBuilder`](`crate::builder::GraphBuilder`). @@ -125,7 +119,8 @@ impl Clone for Input { /// A type that can visit arbitrary [`Input`]s. /// -/// You generally do not implement this trait yourself. An implementation is provided to [`Rule::visit_inputs`]. +/// You generally do not implement this trait yourself. An implementation is provided to +/// [`InputVisitable::visit_inputs`]. pub trait InputVisitor { /// Visit an input whose value is of type `T`. fn visit(&mut self, input: &Input); @@ -146,9 +141,11 @@ impl ConstantRule { impl Rule for ConstantRule { type Output = T; - fn visit_inputs(&self, _visitor: &mut impl InputVisitor) {} - fn evaluate(&mut self) -> Self::Output { self.0.clone() } } + +impl InputVisitable for ConstantRule { + fn visit_inputs(&self, _visitor: &mut impl InputVisitor) {} +} diff --git a/crates/compute_graph_macros/Cargo.toml b/crates/compute_graph_macros/Cargo.toml new file mode 100644 index 0000000..e96fc19 --- /dev/null +++ b/crates/compute_graph_macros/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "compute_graph_macros" +version = "0.1.0" +edition = "2021" + +[lib] +proc-macro = true + +[dependencies] +syn = "2" +quote = "1" +proc-macro2 = "1" diff --git a/crates/compute_graph_macros/src/lib.rs b/crates/compute_graph_macros/src/lib.rs new file mode 100644 index 0000000..cea01fe --- /dev/null +++ b/crates/compute_graph_macros/src/lib.rs @@ -0,0 +1,134 @@ +use proc_macro::TokenStream; +use proc_macro2::Literal; +use quote::{format_ident, quote}; +use syn::{ + parse_macro_input, Data, DataStruct, DeriveInput, Field, Fields, GenericArgument, + PathArguments, Type, +}; + +extern crate proc_macro; + +/// Derive an implementation of the `InputVisitable` trait and helper methods. +/// +/// This macro generates an implementation of the `InputVisitable` trait and the `visit_input` method that +/// calls `visit` on each field of the struct that is of type `Input` for any T. +/// +/// The macro also generates helper methods for accessing the value of each input less verbosely. +/// For unnamed struct fields, the methods generated have the form `input_0`, `input_1`, etc. +/// For named fields, the generated method name matches the field name. In both cases, the method +/// returns a reference to the input value. As with the `Input::value` method, calling the helper methods +/// before the referenced node has been evaluated is forbidden. +#[proc_macro_derive(InputVisitable)] +pub fn derive_rule(input: TokenStream) -> TokenStream { + let input = parse_macro_input!(input as DeriveInput); + if let Data::Struct(ref data) = input.data { + derive_rule_struct(&input, data) + } else { + TokenStream::from( + syn::Error::new(input.ident.span(), "Only structs can derive `Rule`") + .to_compile_error(), + ) + } +} + +fn derive_rule_struct(input: &DeriveInput, data: &DataStruct) -> TokenStream { + let name = &input.ident; + + let visit_inputs = match data.fields { + Fields::Named(ref named) => named + .named + .iter() + .filter(|field| input_value_type(field).is_some()) + .map(|field| { + let ident = field.ident.as_ref().unwrap(); + quote!(visitor.visit(&self.#ident);) + }) + .collect::>(), + Fields::Unnamed(ref unnamed) => unnamed + .unnamed + .iter() + .enumerate() + .filter(|(_, field)| input_value_type(field).is_some()) + .map(|(i, _)| { + let idx_lit = Literal::usize_unsuffixed(i); + quote!(visitor.visit(&self.#idx_lit);) + }) + .collect::>(), + Fields::Unit => vec![], + }; + + let input_value_methods = match data.fields { + Fields::Named(ref named) => named + .named + .iter() + .filter_map(|field| input_value_type(field).map(|ty| (field, ty))) + .map(|(field, ty)| { + let ident = field.ident.as_ref().unwrap(); + quote!( + + fn #ident(&self) -> impl ::std::ops::Deref + '_ { + self.#ident.value() + } + + ) + }) + .collect::>(), + Fields::Unnamed(ref unnamed) => unnamed + .unnamed + .iter() + .enumerate() + .filter_map(|(i, field)| input_value_type(field).map(|ty| (i, ty))) + .map(|(i, ty)| { + let idx_lit = Literal::usize_unsuffixed(i); + let ident = format_ident!("input_{i}"); + quote!( + + fn #ident(&self) -> impl ::std::ops::Deref + '_ { + self.#idx_lit.value() + } + + ) + }) + .collect::>(), + Fields::Unit => vec![], + }; + + TokenStream::from(quote!( + + impl ::compute_graph::rule::InputVisitable for #name { + fn visit_inputs(&self, visitor: &mut impl ::compute_graph::rule::InputVisitor) { + #(#visit_inputs)* + } + } + + impl #name { + #(#input_value_methods)* + } + + )) +} + +fn input_value_type(field: &Field) -> Option<&Type> { + if let Type::Path(ref path) = field.ty { + let last_segment = path.path.segments.last().unwrap(); + if last_segment.ident == "Input" { + if let PathArguments::AngleBracketed(ref args) = last_segment.arguments { + if args.args.len() == 1 { + if let GenericArgument::Type(ref ty) = args.args.first().unwrap() { + Some(ty) + } else { + None + } + } else { + None + } + } else { + None + } + } else { + None + } + } else { + None + } +} diff --git a/crates/derive_test/Cargo.toml b/crates/derive_test/Cargo.toml new file mode 100644 index 0000000..59909cb --- /dev/null +++ b/crates/derive_test/Cargo.toml @@ -0,0 +1,7 @@ +[package] +name = "derive_test" +version = "0.1.0" +edition = "2021" + +[dependencies] +compute_graph = { path = "../compute_graph" } diff --git a/crates/derive_test/src/lib.rs b/crates/derive_test/src/lib.rs new file mode 100644 index 0000000..73be8a1 --- /dev/null +++ b/crates/derive_test/src/lib.rs @@ -0,0 +1,52 @@ +use compute_graph::rule::{Input, InputVisitable, Rule}; + +#[derive(InputVisitable)] +struct Add(Input, Input, i32); + +impl Rule for Add { + type Output = i32; + fn evaluate(&mut self) -> Self::Output { + *self.input_0() + *self.input_1() + self.2 + } +} + +#[derive(InputVisitable)] +struct Add2 { + a: Input, + b: Input, + c: i32, +} + +impl Rule for Add2 { + type Output = i32; + fn evaluate(&mut self) -> Self::Output { + *self.a() + *self.b() + self.c + } +} + +#[cfg(test)] +mod tests { + use compute_graph::builder::GraphBuilder; + + use super::*; + + #[test] + fn test_add() { + let mut builder = GraphBuilder::new(); + let a = builder.add_value(1); + let b = builder.add_value(2); + builder.set_output(Add(a, b, 3)); + let mut graph = builder.build().unwrap(); + assert_eq!(*graph.evaluate(), 6); + } + + #[test] + fn test_add2() { + let mut builder = GraphBuilder::new(); + let a = builder.add_value(1); + let b = builder.add_value(2); + builder.set_output(Add2 { a, b, c: 3 }); + let mut graph = builder.build().unwrap(); + assert_eq!(*graph.evaluate(), 6); + } +}