Make the graph generic over whether it's sync/async
This commit is contained in:
parent
81cd986f77
commit
1530933464
@ -7,3 +7,6 @@ edition = "2021"
|
|||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
petgraph = "0.6.5"
|
petgraph = "0.6.5"
|
||||||
|
|
||||||
|
[dev-dependencies]
|
||||||
|
tokio = { version = "1.41.0", features = ["rt", "macros"] }
|
||||||
|
@ -1,24 +1,117 @@
|
|||||||
#![feature(let_chains)]
|
#![feature(let_chains)]
|
||||||
|
#![feature(async_closure)]
|
||||||
|
|
||||||
mod util;
|
mod util;
|
||||||
|
|
||||||
use petgraph::visit::{IntoEdgeReferences, NodeIndexable};
|
use petgraph::visit::{IntoEdgeReferences, NodeIndexable};
|
||||||
use petgraph::{graph::NodeIndex, stable_graph::StableDiGraph, visit::EdgeRef};
|
use petgraph::{graph::NodeIndex, stable_graph::StableDiGraph, visit::EdgeRef};
|
||||||
use std::cell::{Cell, RefCell};
|
use std::any::Any;
|
||||||
|
use std::cell::{Cell, Ref, RefCell, RefMut};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
use std::collections::VecDeque;
|
||||||
|
use std::future::Future;
|
||||||
|
use std::ops::{Deref, DerefMut};
|
||||||
|
use std::pin::Pin;
|
||||||
use std::rc::Rc;
|
use std::rc::Rc;
|
||||||
use std::{any::Any, collections::VecDeque};
|
|
||||||
|
|
||||||
type NodeGraph = StableDiGraph<ErasedNode, (), u32>;
|
// TODO: consider using a struct for this, because generic bounds of type aliases aren't enforced
|
||||||
|
type NodeGraph<S: Synchronicity> = StableDiGraph<ErasedNode<S>, (), u32>;
|
||||||
|
|
||||||
pub struct Graph<Output> {
|
pub trait Synchronicity: 'static {
|
||||||
// we treat this as a StableGraph, since nodes are never removed
|
type AnyStorage;
|
||||||
node_graph: Rc<RefCell<NodeGraph>>,
|
fn make_any_storage<T: Any>(value: T) -> Self::AnyStorage;
|
||||||
|
fn unbox_any_storage<T: 'static>(storage: &Self::AnyStorage) -> impl Deref<Target = T>;
|
||||||
|
fn unbox_any_storage_mut<T: 'static>(
|
||||||
|
storage: &mut Self::AnyStorage,
|
||||||
|
) -> impl DerefMut<Target = T>;
|
||||||
|
|
||||||
|
type UpdateFn;
|
||||||
|
fn make_update_fn<V: 'static>() -> Self::UpdateFn;
|
||||||
|
|
||||||
|
type UpdateResult<'a>;
|
||||||
|
fn make_update_result<'a>() -> Self::UpdateResult<'a>;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub enum Synchronous {}
|
||||||
|
|
||||||
|
impl Synchronicity for Synchronous {
|
||||||
|
type AnyStorage = Box<dyn Any>;
|
||||||
|
|
||||||
|
fn make_any_storage<T: Any>(value: T) -> Self::AnyStorage {
|
||||||
|
Box::new(value)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn unbox_any_storage<T: 'static>(storage: &Self::AnyStorage) -> impl Deref<Target = T> {
|
||||||
|
storage.downcast_ref().unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn unbox_any_storage_mut<T: 'static>(
|
||||||
|
storage: &mut Self::AnyStorage,
|
||||||
|
) -> impl DerefMut<Target = T> {
|
||||||
|
storage.downcast_mut().unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
type UpdateFn = Box<dyn Fn(&mut Box<dyn Any>) -> ()>;
|
||||||
|
|
||||||
|
fn make_update_fn<V: 'static>() -> Self::UpdateFn {
|
||||||
|
Box::new(|any| {
|
||||||
|
let x = any.downcast_mut::<Box<dyn Node<V, Self>>>().unwrap();
|
||||||
|
x.update();
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
type UpdateResult<'a> = ();
|
||||||
|
|
||||||
|
fn make_update_result<'a>() -> Self::UpdateResult<'a> {}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub enum Asynchronous {}
|
||||||
|
|
||||||
|
impl Synchronicity for Asynchronous {
|
||||||
|
type AnyStorage = Rc<RefCell<Box<dyn Any>>>;
|
||||||
|
|
||||||
|
fn make_any_storage<T: Any>(value: T) -> Self::AnyStorage {
|
||||||
|
Rc::new(RefCell::new(Box::new(value)))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn unbox_any_storage<T: 'static>(storage: &Self::AnyStorage) -> impl Deref<Target = T> {
|
||||||
|
Ref::map(storage.borrow(), |any| any.downcast_ref().unwrap())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn unbox_any_storage_mut<T: 'static>(
|
||||||
|
storage: &mut Self::AnyStorage,
|
||||||
|
) -> impl DerefMut<Target = T> {
|
||||||
|
RefMut::map(storage.borrow_mut(), |any| any.downcast_mut().unwrap())
|
||||||
|
}
|
||||||
|
|
||||||
|
type UpdateFn = Box<dyn Fn(Rc<RefCell<Box<dyn Any>>>) -> Pin<Box<dyn Future<Output = ()>>>>;
|
||||||
|
|
||||||
|
fn make_update_fn<V: 'static>() -> Self::UpdateFn {
|
||||||
|
Box::new(|any| Box::pin(Asynchronous::do_async_update::<V>(any)))
|
||||||
|
}
|
||||||
|
|
||||||
|
type UpdateResult<'a> = Pin<Box<dyn Future<Output = ()> + 'a>>;
|
||||||
|
|
||||||
|
fn make_update_result<'a>() -> Self::UpdateResult<'a> {
|
||||||
|
Box::pin(std::future::ready(()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Asynchronous {
|
||||||
|
async fn do_async_update<V: 'static>(any: Rc<RefCell<Box<dyn Any>>>) {
|
||||||
|
let mut any_ = any.borrow_mut();
|
||||||
|
let x = any_.downcast_mut::<Box<dyn Node<V, Self>>>().unwrap();
|
||||||
|
x.update().await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct Graph<Output, Synch: Synchronicity> {
|
||||||
|
node_graph: Rc<RefCell<NodeGraph<Synch>>>,
|
||||||
output: Option<NodeIndex<u32>>,
|
output: Option<NodeIndex<u32>>,
|
||||||
output_type: std::marker::PhantomData<Output>,
|
output_type: std::marker::PhantomData<Output>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<Output: Clone + 'static> Graph<Output> {
|
impl<Output: Clone + 'static> Graph<Output, Synchronous> {
|
||||||
pub fn new() -> Self {
|
pub fn new() -> Self {
|
||||||
Self {
|
Self {
|
||||||
node_graph: Rc::new(RefCell::new(StableDiGraph::new())),
|
node_graph: Rc::new(RefCell::new(StableDiGraph::new())),
|
||||||
@ -26,13 +119,25 @@ impl<Output: Clone + 'static> Graph<Output> {
|
|||||||
output_type: std::marker::PhantomData,
|
output_type: std::marker::PhantomData,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn set_output<R: Rule<Output> + 'static>(&mut self, rule: R) {
|
impl<Output: Clone + 'static> Graph<Output, Asynchronous> {
|
||||||
|
pub fn new_async() -> Self {
|
||||||
|
Self {
|
||||||
|
node_graph: Rc::new(RefCell::new(StableDiGraph::new())),
|
||||||
|
output: None,
|
||||||
|
output_type: std::marker::PhantomData,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<O: Clone + 'static, S: Synchronicity> Graph<O, S> {
|
||||||
|
pub fn set_output<R: Rule<O> + 'static>(&mut self, rule: R) {
|
||||||
let input = self.add_rule(rule);
|
let input = self.add_rule(rule);
|
||||||
self.output = Some(input.node_idx);
|
self.output = Some(input.node_idx);
|
||||||
}
|
}
|
||||||
|
|
||||||
fn add_node<V: Clone + 'static>(&mut self, node: impl Node<V> + 'static) -> Input<V> {
|
fn add_node<V: Clone + 'static>(&mut self, node: impl Node<V, S> + 'static) -> Input<V> {
|
||||||
let value = node.value_rc();
|
let value = node.value_rc();
|
||||||
let erased = ErasedNode::new(node);
|
let erased = ErasedNode::new(node);
|
||||||
let idx = self.node_graph.borrow_mut().add_node(erased);
|
let idx = self.node_graph.borrow_mut().add_node(erased);
|
||||||
@ -43,7 +148,7 @@ impl<Output: Clone + 'static> Graph<Output> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn add_value<V: Clone + 'static>(&mut self, value: V) -> Input<V> {
|
pub fn add_value<V: Clone + 'static>(&mut self, value: V) -> Input<V> {
|
||||||
return self.add_node(ConstNode(value.clone()));
|
return self.add_node(ConstNode::new(value.clone()));
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn add_rule<R: Rule<V> + 'static, V: Clone + 'static>(&mut self, rule: R) -> Input<V> {
|
pub fn add_rule<R: Rule<V> + 'static, V: Clone + 'static>(&mut self, rule: R) -> Input<V> {
|
||||||
@ -54,7 +159,7 @@ impl<Output: Clone + 'static> Graph<Output> {
|
|||||||
where
|
where
|
||||||
R: Rule<V> + 'static,
|
R: Rule<V> + 'static,
|
||||||
V: Clone + 'static,
|
V: Clone + 'static,
|
||||||
F: FnMut(InvalidationSignal) -> R,
|
F: FnMut(InvalidationSignal<S>) -> R,
|
||||||
{
|
{
|
||||||
let node_idx = Rc::new(Cell::new(None));
|
let node_idx = Rc::new(Cell::new(None));
|
||||||
let signal = InvalidationSignal {
|
let signal = InvalidationSignal {
|
||||||
@ -66,7 +171,7 @@ impl<Output: Clone + 'static> Graph<Output> {
|
|||||||
input
|
input
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn freeze(self) -> Result<FrozenGraph<Output>, GraphFreezeError> {
|
pub fn freeze(self) -> Result<FrozenGraph<O, S>, GraphFreezeError> {
|
||||||
let output: NodeIndex<u32> = match self.output {
|
let output: NodeIndex<u32> = match self.output {
|
||||||
None => return Err(GraphFreezeError::NoOutput),
|
None => return Err(GraphFreezeError::NoOutput),
|
||||||
Some(idx) => idx,
|
Some(idx) => idx,
|
||||||
@ -78,7 +183,7 @@ impl<Output: Clone + 'static> Graph<Output> {
|
|||||||
let mut edges = vec![];
|
let mut edges = vec![];
|
||||||
for idx in indices {
|
for idx in indices {
|
||||||
let node = &mut self.node_graph.borrow_mut()[idx];
|
let node = &mut self.node_graph.borrow_mut()[idx];
|
||||||
(node.visit_inputs)(&mut node.any, &mut |input_idx| {
|
node.visit_inputs(&mut |input_idx| {
|
||||||
edges.push((input_idx, idx));
|
edges.push((input_idx, idx));
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
@ -108,57 +213,53 @@ impl<Output: Clone + 'static> Graph<Output> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl<O: Clone + 'static> Graph<O, Asynchronous> {
|
||||||
|
pub fn set_async_output<R: AsyncRule<O> + 'static>(&mut self, rule: R) {
|
||||||
|
let input = self.add_async_rule(rule);
|
||||||
|
self.output = Some(input.node_idx);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn add_async_rule<R: AsyncRule<V> + 'static, V: Clone + 'static>(
|
||||||
|
&mut self,
|
||||||
|
rule: R,
|
||||||
|
) -> Input<V> {
|
||||||
|
self.add_node(AsyncRuleNode::new(rule))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn add_invalidatable_async_rule<R, V, F>(&mut self, mut f: F) -> Input<V>
|
||||||
|
where
|
||||||
|
R: AsyncRule<V> + 'static,
|
||||||
|
V: Clone + 'static,
|
||||||
|
F: FnMut(InvalidationSignal<Asynchronous>) -> R,
|
||||||
|
{
|
||||||
|
let node_idx = Rc::new(Cell::new(None));
|
||||||
|
let signal = InvalidationSignal {
|
||||||
|
node_idx: Rc::clone(&node_idx),
|
||||||
|
graph: Rc::clone(&self.node_graph),
|
||||||
|
};
|
||||||
|
let input = self.add_async_rule(f(signal));
|
||||||
|
node_idx.set(Some(input.node_idx));
|
||||||
|
input
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub enum GraphFreezeError {
|
pub enum GraphFreezeError {
|
||||||
NoOutput,
|
NoOutput,
|
||||||
Cyclic(Vec<NodeIndex<u32>>),
|
Cyclic(Vec<NodeIndex<u32>>),
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct FrozenGraph<Output> {
|
pub struct FrozenGraph<Output, Synch: Synchronicity> {
|
||||||
node_graph: Rc<RefCell<NodeGraph>>,
|
node_graph: Rc<RefCell<NodeGraph<Synch>>>,
|
||||||
output: NodeIndex<u32>,
|
output: NodeIndex<u32>,
|
||||||
output_type: std::marker::PhantomData<Output>,
|
output_type: std::marker::PhantomData<Output>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<Output: Clone + 'static> FrozenGraph<Output> {
|
impl<O: Clone + 'static, S: Synchronicity> FrozenGraph<O, S> {
|
||||||
fn update_node(&mut self, idx: NodeIndex<u32>) {
|
|
||||||
let graph = self.node_graph.borrow();
|
|
||||||
let node = &graph[idx];
|
|
||||||
let is_valid = (node.is_valid)(&node.any);
|
|
||||||
drop(graph);
|
|
||||||
if !is_valid {
|
|
||||||
// collect all the edges into a vec so that we can mutably borrow the graph to update the nodes
|
|
||||||
let edge_sources = self
|
|
||||||
.node_graph
|
|
||||||
.borrow()
|
|
||||||
.edges_directed(idx, petgraph::Direction::Incoming)
|
|
||||||
.map(|edge| edge.source())
|
|
||||||
.collect::<Vec<_>>();
|
|
||||||
|
|
||||||
// Update the dependencies of this node.
|
|
||||||
// TODO: iterating/recursing here seems less than efficient
|
|
||||||
// instead, in evaluate, topo sort the graph and update invalid nodes?
|
|
||||||
for source in edge_sources {
|
|
||||||
self.update_node(source);
|
|
||||||
}
|
|
||||||
|
|
||||||
let node = &mut self.node_graph.borrow_mut()[idx];
|
|
||||||
// Actually update the node's value.
|
|
||||||
(node.update)(&mut node.any);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn is_output_valid(&self) -> bool {
|
pub fn is_output_valid(&self) -> bool {
|
||||||
let graph = self.node_graph.borrow();
|
let graph = self.node_graph.borrow();
|
||||||
let node = &graph[self.output];
|
let node = &graph[self.output];
|
||||||
(node.is_valid)(&node.any)
|
node.is_valid()
|
||||||
}
|
|
||||||
|
|
||||||
pub fn evaluate(&mut self) -> Output {
|
|
||||||
self.update_node(self.output);
|
|
||||||
let graph = self.node_graph.borrow();
|
|
||||||
let node = &graph[self.output].expect_type::<Output>();
|
|
||||||
node.value_rc().borrow().clone().unwrap()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn node_count(&self) -> usize {
|
pub fn node_count(&self) -> usize {
|
||||||
@ -167,7 +268,7 @@ impl<Output: Clone + 'static> FrozenGraph<Output> {
|
|||||||
|
|
||||||
pub fn modify<F>(&mut self, mut f: F) -> Result<(), GraphFreezeError>
|
pub fn modify<F>(&mut self, mut f: F) -> Result<(), GraphFreezeError>
|
||||||
where
|
where
|
||||||
F: FnMut(&mut Graph<Output>) -> (),
|
F: FnMut(&mut Graph<O, S>) -> (),
|
||||||
{
|
{
|
||||||
// Copy all the current edges so we can check if any change.
|
// Copy all the current edges so we can check if any change.
|
||||||
let graph = self.node_graph.borrow();
|
let graph = self.node_graph.borrow();
|
||||||
@ -206,12 +307,74 @@ impl<Output: Clone + 'static> FrozenGraph<Output> {
|
|||||||
to_invalidate.push_back(edge.target());
|
to_invalidate.push_back(edge.target());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
invalidate_nodes(&mut graph, to_invalidate);
|
invalidate_nodes::<S>(&mut graph, to_invalidate);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl<Output: Clone + 'static> FrozenGraph<Output, Synchronous> {
|
||||||
|
fn update_node(&mut self, idx: NodeIndex<u32>) {
|
||||||
|
let graph = self.node_graph.borrow();
|
||||||
|
let node = &graph[idx];
|
||||||
|
if !node.is_valid() {
|
||||||
|
// collect all the edges into a vec so that we can mutably borrow the graph to update the nodes
|
||||||
|
let edge_sources = graph
|
||||||
|
.edges_directed(idx, petgraph::Direction::Incoming)
|
||||||
|
.map(|edge| edge.source())
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
drop(graph);
|
||||||
|
|
||||||
|
// Update the dependencies of this node.
|
||||||
|
// TODO: iterating/recursing here seems less than efficient
|
||||||
|
// instead, in evaluate, topo sort the graph and update invalid nodes?
|
||||||
|
for source in edge_sources {
|
||||||
|
self.update_node(source);
|
||||||
|
}
|
||||||
|
|
||||||
|
let node = &mut self.node_graph.borrow_mut()[idx];
|
||||||
|
// Actually update the node's value.
|
||||||
|
(node.update)(&mut node.any);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn evaluate(&mut self) -> Output {
|
||||||
|
self.update_node(self.output);
|
||||||
|
let graph = self.node_graph.borrow();
|
||||||
|
let node = &graph[self.output].expect_type::<Output>();
|
||||||
|
node.value_rc().borrow().clone().unwrap()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<Output: Clone + 'static> FrozenGraph<Output, Asynchronous> {
|
||||||
|
async fn update_node_async(&mut self, idx: NodeIndex<u32>) {
|
||||||
|
// TODO: same note about recursing as above, and consider doing this in parallel
|
||||||
|
let graph = self.node_graph.borrow();
|
||||||
|
let node = &graph[idx];
|
||||||
|
if !node.is_valid() {
|
||||||
|
let edge_sources = graph
|
||||||
|
.edges_directed(idx, petgraph::Direction::Incoming)
|
||||||
|
.map(|edge| edge.source())
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
drop(graph);
|
||||||
|
|
||||||
|
for source in edge_sources {
|
||||||
|
Box::pin(self.update_node_async(source)).await;
|
||||||
|
}
|
||||||
|
|
||||||
|
let node = &self.node_graph.borrow()[idx];
|
||||||
|
(node.update)(Rc::clone(&node.any)).await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn evaluate_async(&mut self) -> Output {
|
||||||
|
self.update_node_async(self.output).await;
|
||||||
|
let graph = self.node_graph.borrow();
|
||||||
|
let node = &graph[self.output].expect_type::<Output>();
|
||||||
|
node.value_rc().borrow().clone().unwrap()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug)]
|
#[derive(Clone, Debug)]
|
||||||
pub struct Input<T> {
|
pub struct Input<T> {
|
||||||
node_idx: NodeIndex<u32>,
|
node_idx: NodeIndex<u32>,
|
||||||
@ -219,7 +382,7 @@ pub struct Input<T> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl<T: Clone + 'static> Input<T> {
|
impl<T: Clone + 'static> Input<T> {
|
||||||
fn value(&self) -> T {
|
pub fn value(&self) -> T {
|
||||||
self.value
|
self.value
|
||||||
.as_ref()
|
.as_ref()
|
||||||
.borrow()
|
.borrow()
|
||||||
@ -229,24 +392,27 @@ impl<T: Clone + 'static> Input<T> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// TODO: there's a lot happening here, make sure this doesn't create a reference cycle
|
// TODO: there's a lot happening here, make sure this doesn't create a reference cycle
|
||||||
pub struct InvalidationSignal {
|
pub struct InvalidationSignal<Synch: Synchronicity> {
|
||||||
node_idx: Rc<Cell<Option<NodeIndex<u32>>>>,
|
node_idx: Rc<Cell<Option<NodeIndex<u32>>>>,
|
||||||
graph: Rc<RefCell<NodeGraph>>,
|
graph: Rc<RefCell<NodeGraph<Synch>>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl InvalidationSignal {
|
impl<S: Synchronicity> InvalidationSignal<S> {
|
||||||
pub fn invalidate(&self) {
|
pub fn invalidate(&self) {
|
||||||
let mut queue = VecDeque::new();
|
let mut queue = VecDeque::new();
|
||||||
queue.push_back(self.node_idx.get().unwrap());
|
queue.push_back(self.node_idx.get().unwrap());
|
||||||
invalidate_nodes(&mut *self.graph.borrow_mut(), queue);
|
invalidate_nodes::<S>(&mut *self.graph.borrow_mut(), queue);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn invalidate_nodes(graph: &mut NodeGraph, mut queue: VecDeque<NodeIndex<u32>>) {
|
fn invalidate_nodes<S: Synchronicity>(
|
||||||
|
graph: &mut NodeGraph<S>,
|
||||||
|
mut queue: VecDeque<NodeIndex<u32>>,
|
||||||
|
) {
|
||||||
while let Some(idx) = queue.pop_front() {
|
while let Some(idx) = queue.pop_front() {
|
||||||
let node = &mut graph[idx];
|
let node = &mut graph[idx];
|
||||||
if (node.is_valid)(&node.any) {
|
if node.is_valid() {
|
||||||
(node.invalidate)(&mut node.any);
|
node.invalidate();
|
||||||
let dependents = graph
|
let dependents = graph
|
||||||
.edges_directed(idx, petgraph::Direction::Outgoing)
|
.edges_directed(idx, petgraph::Direction::Outgoing)
|
||||||
.map(|edge| edge.target());
|
.map(|edge| edge.target());
|
||||||
@ -257,62 +423,74 @@ fn invalidate_nodes(graph: &mut NodeGraph, mut queue: VecDeque<NodeIndex<u32>>)
|
|||||||
|
|
||||||
// TODO: i really want Input to be able to implement Deref somehow
|
// TODO: i really want Input to be able to implement Deref somehow
|
||||||
|
|
||||||
struct ErasedNode {
|
pub struct ErasedNode<Synch: Synchronicity> {
|
||||||
any: Box<dyn Any>,
|
any: Synch::AnyStorage,
|
||||||
is_valid: Box<dyn Fn(&Box<dyn Any>) -> bool>,
|
is_valid: Box<dyn Fn(&Synch::AnyStorage) -> bool>,
|
||||||
invalidate: Box<dyn Fn(&mut Box<dyn Any>) -> ()>,
|
invalidate: Box<dyn Fn(&mut Synch::AnyStorage) -> ()>,
|
||||||
visit_inputs: Box<dyn Fn(&mut Box<dyn Any>, &mut dyn FnMut(NodeIndex<u32>) -> ()) -> ()>,
|
visit_inputs: Box<dyn Fn(&mut Synch::AnyStorage, &mut dyn FnMut(NodeIndex<u32>) -> ()) -> ()>,
|
||||||
update: Box<dyn Fn(&mut Box<dyn Any>) -> ()>,
|
update: Synch::UpdateFn,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ErasedNode {
|
impl<S: Synchronicity> ErasedNode<S> {
|
||||||
fn new<N: Node<V> + 'static, V: 'static>(base: N) -> Self {
|
fn new<N: Node<V, S> + 'static, V: 'static>(base: N) -> Self {
|
||||||
// i don't love the double boxing, but i'm not sure how else to do this
|
// i don't love the double boxing, but i'm not sure how else to do this
|
||||||
let thing: Box<dyn Node<V>> = Box::new(base);
|
let thing: Box<dyn Node<V, S>> = Box::new(base);
|
||||||
let any: Box<dyn Any> = Box::new(thing);
|
|
||||||
Self {
|
Self {
|
||||||
any,
|
any: S::make_any_storage(thing),
|
||||||
is_valid: Box::new(|any| {
|
is_valid: Box::new(|any| {
|
||||||
let x = any.downcast_ref::<Box<dyn Node<V>>>().unwrap();
|
let x = S::unbox_any_storage::<Box<dyn Node<V, S>>>(any);
|
||||||
x.is_valid()
|
x.is_valid()
|
||||||
}),
|
}),
|
||||||
invalidate: Box::new(|any| {
|
invalidate: Box::new(|any| {
|
||||||
let x = any.downcast_mut::<Box<dyn Node<V>>>().unwrap();
|
let mut x = S::unbox_any_storage_mut::<Box<dyn Node<V, S>>>(any);
|
||||||
x.invalidate();
|
x.invalidate();
|
||||||
}),
|
}),
|
||||||
visit_inputs: Box::new(|any, visitor| {
|
visit_inputs: Box::new(|any, visitor| {
|
||||||
let x = any.downcast_mut::<Box<dyn Node<V>>>().unwrap();
|
let mut x = S::unbox_any_storage_mut::<Box<dyn Node<V, S>>>(any);
|
||||||
x.visit_inputs(visitor);
|
x.visit_inputs(visitor);
|
||||||
}),
|
}),
|
||||||
update: Box::new(|any| {
|
update: S::make_update_fn::<V>(),
|
||||||
let x = any.downcast_mut::<Box<dyn Node<V>>>().unwrap();
|
|
||||||
x.update();
|
|
||||||
}),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: revisit if these are necessary
|
fn expect_type<'a, V: 'static>(&'a self) -> impl Deref<Target = Box<dyn Node<V, S>>> + 'a {
|
||||||
fn expect_type<'a, V: 'static>(&'a self) -> &'a dyn Node<V> {
|
S::unbox_any_storage::<Box<dyn Node<V, S>>>(&self.any)
|
||||||
let res = self
|
}
|
||||||
.any
|
|
||||||
.downcast_ref::<Box<dyn Node<V>>>()
|
fn is_valid(&self) -> bool {
|
||||||
.expect("matching node type");
|
(self.is_valid)(&self.any)
|
||||||
res.as_ref()
|
}
|
||||||
|
fn invalidate(&mut self) {
|
||||||
|
(self.invalidate)(&mut self.any);
|
||||||
|
}
|
||||||
|
fn visit_inputs(&mut self, f: &mut dyn FnMut(NodeIndex<u32>) -> ()) {
|
||||||
|
(self.visit_inputs)(&mut self.any, f);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
trait Node<Value> {
|
trait Node<Value: 'static, Synch: Synchronicity> {
|
||||||
fn is_valid(&self) -> bool;
|
fn is_valid(&self) -> bool;
|
||||||
fn invalidate(&mut self);
|
fn invalidate(&mut self);
|
||||||
fn visit_inputs(&mut self, visitor: &mut dyn FnMut(NodeIndex<u32>) -> ());
|
fn visit_inputs(&mut self, visitor: &mut dyn FnMut(NodeIndex<u32>) -> ());
|
||||||
fn update(&mut self);
|
fn update(&mut self) -> Synch::UpdateResult<'_>;
|
||||||
// TODO: are these both necessary?
|
|
||||||
fn value_rc(&self) -> Rc<RefCell<Option<Value>>>;
|
fn value_rc(&self) -> Rc<RefCell<Option<Value>>>;
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ConstNode<V>(V);
|
struct ConstNode<V, S> {
|
||||||
|
value: V,
|
||||||
|
synchronicity: std::marker::PhantomData<S>,
|
||||||
|
}
|
||||||
|
|
||||||
impl<V: Clone + 'static> Node<V> for ConstNode<V> {
|
impl<V, S> ConstNode<V, S> {
|
||||||
|
fn new(value: V) -> Self {
|
||||||
|
Self {
|
||||||
|
value,
|
||||||
|
synchronicity: std::marker::PhantomData,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<V: Clone + 'static, S: Synchronicity> Node<V, S> for ConstNode<V, S> {
|
||||||
fn is_valid(&self) -> bool {
|
fn is_valid(&self) -> bool {
|
||||||
true
|
true
|
||||||
}
|
}
|
||||||
@ -321,30 +499,34 @@ impl<V: Clone + 'static> Node<V> for ConstNode<V> {
|
|||||||
|
|
||||||
fn visit_inputs(&mut self, _visitor: &mut dyn FnMut(NodeIndex<u32>) -> ()) {}
|
fn visit_inputs(&mut self, _visitor: &mut dyn FnMut(NodeIndex<u32>) -> ()) {}
|
||||||
|
|
||||||
fn update(&mut self) {}
|
fn update(&mut self) -> <S as Synchronicity>::UpdateResult<'_> {
|
||||||
|
unreachable!()
|
||||||
|
}
|
||||||
|
|
||||||
fn value_rc(&self) -> Rc<RefCell<Option<V>>> {
|
fn value_rc(&self) -> Rc<RefCell<Option<V>>> {
|
||||||
Rc::new(RefCell::new(Some(self.0.clone())))
|
Rc::new(RefCell::new(Some(self.value.clone())))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct RuleNode<R, V> {
|
struct RuleNode<R, V, S> {
|
||||||
rule: R,
|
rule: R,
|
||||||
value: Rc<RefCell<Option<V>>>,
|
value: Rc<RefCell<Option<V>>>,
|
||||||
valid: bool,
|
valid: bool,
|
||||||
|
synchronicity: std::marker::PhantomData<S>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<R: Rule<V>, V> RuleNode<R, V> {
|
impl<R: Rule<V>, V, S> RuleNode<R, V, S> {
|
||||||
fn new(rule: R) -> Self {
|
fn new(rule: R) -> Self {
|
||||||
Self {
|
Self {
|
||||||
rule,
|
rule,
|
||||||
value: Rc::new(RefCell::new(None)),
|
value: Rc::new(RefCell::new(None)),
|
||||||
valid: false,
|
valid: false,
|
||||||
|
synchronicity: std::marker::PhantomData,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<R: Rule<V> + 'static, V: Clone + 'static> Node<V> for RuleNode<R, V> {
|
impl<R: Rule<V> + 'static, V: Clone + 'static, S: Synchronicity> Node<V, S> for RuleNode<R, V, S> {
|
||||||
fn is_valid(&self) -> bool {
|
fn is_valid(&self) -> bool {
|
||||||
self.valid
|
self.valid
|
||||||
}
|
}
|
||||||
@ -363,10 +545,11 @@ impl<R: Rule<V> + 'static, V: Clone + 'static> Node<V> for RuleNode<R, V> {
|
|||||||
self.rule.visit_inputs(&mut InputIndexVisitor(visitor));
|
self.rule.visit_inputs(&mut InputIndexVisitor(visitor));
|
||||||
}
|
}
|
||||||
|
|
||||||
fn update(&mut self) {
|
fn update(&mut self) -> <S as Synchronicity>::UpdateResult<'_> {
|
||||||
self.valid = true;
|
|
||||||
let new_value = self.rule.evaluate();
|
let new_value = self.rule.evaluate();
|
||||||
|
self.valid = true;
|
||||||
*self.value.borrow_mut() = Some(new_value);
|
*self.value.borrow_mut() = Some(new_value);
|
||||||
|
S::make_update_result()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn value_rc(&self) -> Rc<RefCell<Option<V>>> {
|
fn value_rc(&self) -> Rc<RefCell<Option<V>>> {
|
||||||
@ -374,12 +557,70 @@ impl<R: Rule<V> + 'static, V: Clone + 'static> Node<V> for RuleNode<R, V> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct AsyncRuleNode<R, V> {
|
||||||
|
rule: R,
|
||||||
|
value: Rc<RefCell<Option<V>>>,
|
||||||
|
valid: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<R: AsyncRule<V>, V> AsyncRuleNode<R, V> {
|
||||||
|
fn new(rule: R) -> Self {
|
||||||
|
Self {
|
||||||
|
rule,
|
||||||
|
value: Rc::new(RefCell::new(None)),
|
||||||
|
valid: false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<R: AsyncRule<V> + 'static, V: Clone + 'static> Node<V, Asynchronous> for AsyncRuleNode<R, V> {
|
||||||
|
fn is_valid(&self) -> bool {
|
||||||
|
self.valid
|
||||||
|
}
|
||||||
|
|
||||||
|
fn invalidate(&mut self) {
|
||||||
|
self.valid = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
fn visit_inputs(&mut self, visitor: &mut dyn FnMut(NodeIndex<u32>) -> ()) {
|
||||||
|
struct InputIndexVisitor<'a>(&'a mut dyn FnMut(NodeIndex<u32>) -> ());
|
||||||
|
impl<'a> InputVisitor for InputIndexVisitor<'a> {
|
||||||
|
fn visit<T>(&mut self, input: &mut Input<T>) {
|
||||||
|
self.0(input.node_idx);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
self.rule.visit_inputs(&mut InputIndexVisitor(visitor));
|
||||||
|
}
|
||||||
|
|
||||||
|
fn update(&mut self) -> <Asynchronous as Synchronicity>::UpdateResult<'_> {
|
||||||
|
Box::pin(self.do_update())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn value_rc(&self) -> Rc<RefCell<Option<V>>> {
|
||||||
|
Rc::clone(&self.value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<R: AsyncRule<V>, V> AsyncRuleNode<R, V> {
|
||||||
|
async fn do_update(&mut self) {
|
||||||
|
let new_value = self.rule.evaluate().await;
|
||||||
|
self.valid = true;
|
||||||
|
*self.value.borrow_mut() = Some(new_value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub trait Rule<Output> {
|
pub trait Rule<Output> {
|
||||||
fn visit_inputs(&mut self, visitor: &mut impl InputVisitor);
|
fn visit_inputs(&mut self, visitor: &mut impl InputVisitor);
|
||||||
|
|
||||||
fn evaluate(&mut self) -> Output;
|
fn evaluate(&mut self) -> Output;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub trait AsyncRule<Output> {
|
||||||
|
fn visit_inputs(&mut self, visitor: &mut impl InputVisitor);
|
||||||
|
|
||||||
|
async fn evaluate(&mut self) -> Output;
|
||||||
|
}
|
||||||
|
|
||||||
pub trait InputVisitor {
|
pub trait InputVisitor {
|
||||||
fn visit<T>(&mut self, input: &mut Input<T>);
|
fn visit<T>(&mut self, input: &mut Input<T>);
|
||||||
}
|
}
|
||||||
@ -390,7 +631,7 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn erase_node() {
|
fn erase_node() {
|
||||||
let node = ErasedNode::new(ConstNode(1234 as i32));
|
let node = ErasedNode::<Synchronous>::new(ConstNode::new(1234 as i32));
|
||||||
let unwrapped = node.expect_type::<i32>();
|
let unwrapped = node.expect_type::<i32>();
|
||||||
assert_eq!(unwrapped.value_rc().borrow().unwrap(), 1234);
|
assert_eq!(unwrapped.value_rc().borrow().unwrap(), 1234);
|
||||||
}
|
}
|
||||||
@ -512,7 +753,7 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn cant_freeze_no_output() {
|
fn cant_freeze_no_output() {
|
||||||
let graph = Graph::<i32>::new();
|
let graph = Graph::<i32, Synchronous>::new();
|
||||||
match graph.freeze() {
|
match graph.freeze() {
|
||||||
Err(GraphFreezeError::NoOutput) => (),
|
Err(GraphFreezeError::NoOutput) => (),
|
||||||
Err(e) => assert!(false, "unexpected error {:?}", e),
|
Err(e) => assert!(false, "unexpected error {:?}", e),
|
||||||
@ -579,4 +820,27 @@ mod tests {
|
|||||||
assert!(!frozen.is_output_valid());
|
assert!(!frozen.is_output_valid());
|
||||||
assert_eq!(frozen.evaluate(), 2);
|
assert_eq!(frozen.evaluate(), 2);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn async_graph() {
|
||||||
|
let mut graph = Graph::new_async();
|
||||||
|
graph.set_output(ConstantRule(42));
|
||||||
|
let mut frozen = graph.freeze().unwrap();
|
||||||
|
assert_eq!(frozen.evaluate_async().await, 42);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn async_rule() {
|
||||||
|
struct AsyncConst(i32);
|
||||||
|
impl AsyncRule<i32> for AsyncConst {
|
||||||
|
fn visit_inputs(&mut self, _visitor: &mut impl InputVisitor) {}
|
||||||
|
async fn evaluate(&mut self) -> i32 {
|
||||||
|
self.0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let mut graph = Graph::new_async();
|
||||||
|
graph.set_async_output(AsyncConst(42));
|
||||||
|
let mut frozen = graph.freeze().unwrap();
|
||||||
|
assert_eq!(frozen.evaluate_async().await, 42);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user