Fix derive macro with generics

This commit is contained in:
Shadowfacts 2024-11-03 14:57:21 -05:00
parent e69014d98d
commit b8ad929d0b
2 changed files with 33 additions and 5 deletions

View File

@ -2,7 +2,7 @@ use proc_macro::TokenStream;
use proc_macro2::Literal;
use quote::{format_ident, quote};
use syn::{
parse_macro_input, Data, DataStruct, DeriveInput, Field, Fields, GenericArgument,
parse_macro_input, Data, DataStruct, DeriveInput, Field, Fields, GenericArgument, GenericParam,
PathArguments, Type,
};
@ -25,7 +25,10 @@ pub fn derive_rule(input: TokenStream) -> TokenStream {
derive_rule_struct(&input, data)
} else {
TokenStream::from(
syn::Error::new(input.ident.span(), "Only structs can derive `Rule`")
syn::Error::new(
input.ident.span(),
"Only structs can derive `InputVisitable`",
)
.to_compile_error(),
)
}
@ -33,6 +36,21 @@ pub fn derive_rule(input: TokenStream) -> TokenStream {
fn derive_rule_struct(input: &DeriveInput, data: &DataStruct) -> TokenStream {
let name = &input.ident;
let lt = &input.generics.lt_token;
let params = &input.generics.params;
let gt = &input.generics.gt_token;
let generics = quote!(#lt #params #gt);
let where_clause = &input.generics.where_clause;
let params_only_names = params.iter().map(|p| match p {
GenericParam::Lifetime(_) => {
panic!("Lifetime generics aren't supported when deriving `InputVisitable`")
}
GenericParam::Type(ty) => &ty.ident,
GenericParam::Const(_) => {
panic!("Const generics aren't supported when deriving `InputVisitable`")
}
});
let generics_only_names = quote!(#lt #(#params_only_names),* #gt);
let visit_inputs = match data.fields {
Fields::Named(ref named) => named
@ -95,13 +113,13 @@ fn derive_rule_struct(input: &DeriveInput, data: &DataStruct) -> TokenStream {
TokenStream::from(quote!(
impl ::compute_graph::rule::InputVisitable for #name {
impl #generics ::compute_graph::rule::InputVisitable for #name #generics_only_names #where_clause {
fn visit_inputs(&self, visitor: &mut impl ::compute_graph::rule::InputVisitor) {
#(#visit_inputs)*
}
}
impl #name {
impl #generics #name #generics_only_names #where_clause {
#(#input_value_methods)*
}

View File

@ -1,3 +1,4 @@
use compute_graph::node::NodeValue;
use compute_graph::rule::{Input, InputVisitable, Rule};
#[derive(InputVisitable)]
@ -24,6 +25,15 @@ impl Rule for Add2 {
}
}
#[derive(InputVisitable)]
struct Passthrough<T: NodeValue + Clone>(Input<T>);
impl<T: NodeValue + Clone> Rule for Passthrough<T> {
type Output = T;
fn evaluate(&mut self) -> Self::Output {
self.input_0().clone()
}
}
#[cfg(test)]
mod tests {
use compute_graph::builder::GraphBuilder;