Use associated types for rule outputs

This commit is contained in:
Shadowfacts 2024-10-31 11:19:16 -04:00
parent d8f2a393ba
commit ca0b77349a

View File

@ -107,7 +107,7 @@ impl<Output: Clone + 'static> Graph<Output, Asynchronous> {
} }
impl<O: Clone + 'static, S: Synchronicity> Graph<O, S> { impl<O: Clone + 'static, S: Synchronicity> Graph<O, S> {
pub fn set_output<R: Rule<O> + 'static>(&mut self, rule: R) { pub fn set_output<R: Rule<Output = O>>(&mut self, rule: R) {
let input = self.add_rule(rule); let input = self.add_rule(rule);
self.output = Some(input.node_idx); self.output = Some(input.node_idx);
} }
@ -126,14 +126,18 @@ impl<O: Clone + 'static, S: Synchronicity> Graph<O, S> {
return self.add_node(ConstNode::new(value.clone())); return self.add_node(ConstNode::new(value.clone()));
} }
pub fn add_rule<R: Rule<V> + 'static, V: Clone + 'static>(&mut self, rule: R) -> Input<V> { pub fn add_rule<R>(&mut self, rule: R) -> Input<R::Output>
where
R: Rule,
R::Output: Clone,
{
return self.add_node(RuleNode::new(rule)); return self.add_node(RuleNode::new(rule));
} }
pub fn add_invalidatable_rule<R, V, F>(&mut self, mut f: F) -> Input<V> pub fn add_invalidatable_rule<R, F>(&mut self, mut f: F) -> Input<R::Output>
where where
R: Rule<V> + 'static, R: Rule,
V: Clone + 'static, R::Output: Clone,
F: FnMut(InvalidationSignal<S>) -> R, F: FnMut(InvalidationSignal<S>) -> R,
{ {
let node_idx = Rc::new(Cell::new(None)); let node_idx = Rc::new(Cell::new(None));
@ -189,22 +193,23 @@ impl<O: Clone + 'static, S: Synchronicity> Graph<O, S> {
} }
impl<O: Clone + 'static> Graph<O, Asynchronous> { impl<O: Clone + 'static> Graph<O, Asynchronous> {
pub fn set_async_output<R: AsyncRule<O> + 'static>(&mut self, rule: R) { pub fn set_async_output<R: AsyncRule<Output = O>>(&mut self, rule: R) {
let input = self.add_async_rule(rule); let input = self.add_async_rule(rule);
self.output = Some(input.node_idx); self.output = Some(input.node_idx);
} }
pub fn add_async_rule<R: AsyncRule<V> + 'static, V: Clone + 'static>( pub fn add_async_rule<R>(&mut self, rule: R) -> Input<R::Output>
&mut self, where
rule: R, R: AsyncRule,
) -> Input<V> { R::Output: Clone,
{
self.add_node(AsyncRuleNode::new(rule)) self.add_node(AsyncRuleNode::new(rule))
} }
pub fn add_invalidatable_async_rule<R, V, F>(&mut self, mut f: F) -> Input<V> pub fn add_invalidatable_async_rule<R, F>(&mut self, mut f: F) -> Input<R::Output>
where where
R: AsyncRule<V> + 'static, R: AsyncRule,
V: Clone + 'static, R::Output: Clone,
F: FnMut(InvalidationSignal<Asynchronous>) -> R, F: FnMut(InvalidationSignal<Asynchronous>) -> R,
{ {
let node_idx = Rc::new(Cell::new(None)); let node_idx = Rc::new(Cell::new(None));
@ -494,7 +499,7 @@ struct RuleNode<R, V, S> {
synchronicity: std::marker::PhantomData<S>, synchronicity: std::marker::PhantomData<S>,
} }
impl<R: Rule<V>, V, S> RuleNode<R, V, S> { impl<R: Rule, S> RuleNode<R, R::Output, S> {
fn new(rule: R) -> Self { fn new(rule: R) -> Self {
Self { Self {
rule, rule,
@ -505,7 +510,7 @@ impl<R: Rule<V>, V, S> RuleNode<R, V, S> {
} }
} }
impl<R: Rule<V> + 'static, V: Clone + 'static, S: Synchronicity> Node<V, S> for RuleNode<R, V, S> { impl<R: Rule, S: Synchronicity> Node<R::Output, S> for RuleNode<R, R::Output, S> {
fn is_valid(&self) -> bool { fn is_valid(&self) -> bool {
self.valid self.valid
} }
@ -531,7 +536,7 @@ impl<R: Rule<V> + 'static, V: Clone + 'static, S: Synchronicity> Node<V, S> for
S::make_update_result() S::make_update_result()
} }
fn value_rc(&self) -> Rc<RefCell<Option<V>>> { fn value_rc(&self) -> Rc<RefCell<Option<R::Output>>> {
Rc::clone(&self.value) Rc::clone(&self.value)
} }
} }
@ -542,7 +547,7 @@ struct AsyncRuleNode<R, V> {
valid: bool, valid: bool,
} }
impl<R: AsyncRule<V>, V> AsyncRuleNode<R, V> { impl<R: AsyncRule> AsyncRuleNode<R, R::Output> {
fn new(rule: R) -> Self { fn new(rule: R) -> Self {
Self { Self {
rule, rule,
@ -552,7 +557,10 @@ impl<R: AsyncRule<V>, V> AsyncRuleNode<R, V> {
} }
} }
impl<R: AsyncRule<V> + 'static, V: Clone + 'static> Node<V, Asynchronous> for AsyncRuleNode<R, V> { impl<R: AsyncRule> Node<R::Output, Asynchronous> for AsyncRuleNode<R, R::Output>
where
R::Output: Clone,
{
fn is_valid(&self) -> bool { fn is_valid(&self) -> bool {
self.valid self.valid
} }
@ -575,12 +583,12 @@ impl<R: AsyncRule<V> + 'static, V: Clone + 'static> Node<V, Asynchronous> for As
Box::pin(self.do_update()) Box::pin(self.do_update())
} }
fn value_rc(&self) -> Rc<RefCell<Option<V>>> { fn value_rc(&self) -> Rc<RefCell<Option<R::Output>>> {
Rc::clone(&self.value) Rc::clone(&self.value)
} }
} }
impl<R: AsyncRule<V>, V> AsyncRuleNode<R, V> { impl<R: AsyncRule> AsyncRuleNode<R, R::Output> {
async fn do_update(&mut self) { async fn do_update(&mut self) {
let new_value = self.rule.evaluate().await; let new_value = self.rule.evaluate().await;
self.valid = true; self.valid = true;
@ -588,16 +596,20 @@ impl<R: AsyncRule<V>, V> AsyncRuleNode<R, V> {
} }
} }
pub trait Rule<Output> { pub trait Rule: 'static {
type Output;
fn visit_inputs(&mut self, visitor: &mut impl InputVisitor); fn visit_inputs(&mut self, visitor: &mut impl InputVisitor);
fn evaluate(&mut self) -> Output; fn evaluate(&mut self) -> Self::Output;
} }
pub trait AsyncRule<Output> { pub trait AsyncRule: 'static {
type Output: 'static;
fn visit_inputs(&mut self, visitor: &mut impl InputVisitor); fn visit_inputs(&mut self, visitor: &mut impl InputVisitor);
async fn evaluate(&mut self) -> Output; async fn evaluate(&mut self) -> Self::Output;
} }
pub trait InputVisitor { pub trait InputVisitor {
@ -616,7 +628,8 @@ mod tests {
} }
struct ConstantRule(i32); struct ConstantRule(i32);
impl Rule<i32> for ConstantRule { impl Rule for ConstantRule {
type Output = i32;
fn visit_inputs(&mut self, _visitor: &mut impl InputVisitor) {} fn visit_inputs(&mut self, _visitor: &mut impl InputVisitor) {}
fn evaluate(&mut self) -> i32 { fn evaluate(&mut self) -> i32 {
self.0 self.0
@ -641,7 +654,8 @@ mod tests {
} }
struct Double(Input<i32>); struct Double(Input<i32>);
impl Rule<i32> for Double { impl Rule for Double {
type Output = i32;
fn visit_inputs(&mut self, visitor: &mut impl InputVisitor) { fn visit_inputs(&mut self, visitor: &mut impl InputVisitor) {
visitor.visit(&mut self.0); visitor.visit(&mut self.0);
} }
@ -667,7 +681,8 @@ mod tests {
assert_eq!(graph.freeze().unwrap().evaluate(), 168); assert_eq!(graph.freeze().unwrap().evaluate(), 168);
} }
struct Inc(i32); struct Inc(i32);
impl Rule<i32> for Inc { impl Rule for Inc {
type Output = i32;
fn visit_inputs(&mut self, _visitor: &mut impl InputVisitor) {} fn visit_inputs(&mut self, _visitor: &mut impl InputVisitor) {}
fn evaluate(&mut self) -> i32 { fn evaluate(&mut self) -> i32 {
self.0 += 1; self.0 += 1;
@ -694,7 +709,8 @@ mod tests {
} }
struct Add(Input<i32>, Input<i32>); struct Add(Input<i32>, Input<i32>);
impl Rule<i32> for Add { impl Rule for Add {
type Output = i32;
fn visit_inputs(&mut self, visitor: &mut impl InputVisitor) { fn visit_inputs(&mut self, visitor: &mut impl InputVisitor) {
visitor.visit(&mut self.0); visitor.visit(&mut self.0);
visitor.visit(&mut self.1); visitor.visit(&mut self.1);
@ -741,7 +757,8 @@ mod tests {
} }
struct DeferredInput(Rc<RefCell<Option<Input<i32>>>>); struct DeferredInput(Rc<RefCell<Option<Input<i32>>>>);
impl Rule<i32> for DeferredInput { impl Rule for DeferredInput {
type Output = i32;
fn visit_inputs(&mut self, visitor: &mut impl InputVisitor) { fn visit_inputs(&mut self, visitor: &mut impl InputVisitor) {
let mut borrowed = self.0.borrow_mut(); let mut borrowed = self.0.borrow_mut();
let input = borrowed.as_mut().unwrap(); let input = borrowed.as_mut().unwrap();
@ -811,7 +828,8 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn async_rule() { async fn async_rule() {
struct AsyncConst(i32); struct AsyncConst(i32);
impl AsyncRule<i32> for AsyncConst { impl AsyncRule for AsyncConst {
type Output = i32;
fn visit_inputs(&mut self, _visitor: &mut impl InputVisitor) {} fn visit_inputs(&mut self, _visitor: &mut impl InputVisitor) {}
async fn evaluate(&mut self) -> i32 { async fn evaluate(&mut self) -> i32 {
self.0 self.0