paiagram/
graph.rs

1use crate::graph::arrange::GraphLayoutTask;
2use crate::units::speed::Velocity;
3use crate::vehicles::entries::TimetableEntry;
4use bevy::ecs::entity::{EntityHashMap, EntityMapper, MapEntities};
5use bevy::prelude::*;
6use either::Either;
7use moonshine_core::kind::prelude::*;
8use moonshine_core::save::prelude::*;
9use petgraph::prelude::*;
10use serde::{Deserialize, Deserializer, Serialize, Serializer};
11
12pub mod arrange;
13
14/// The graph type used for the transportation network
15pub type IntervalGraphType = StableDiGraph<Instance<Station>, Instance<Interval>>;
16/// A raw graph type used for serialization/deserialization
17#[derive(Serialize, Deserialize, MapEntities)]
18pub struct RawIntervalGraphType(StableDiGraph<Entity, Entity>);
19
20/// A graph representing the transportation network
21#[derive(Reflect, Clone, Resource, Default, Debug)]
22#[reflect(Resource, opaque, Serialize, Deserialize, MapEntities)]
23pub struct Graph {
24    /// The inner graph structure
25    inner: IntervalGraphType,
26    /// Mapping from station entities to their node indices in the graph
27    /// This is skipped during serialization/deserialization and rebuilt as needed
28    #[reflect(ignore)]
29    indices: EntityHashMap<NodeIndex>,
30}
31
32impl From<RawIntervalGraphType> for IntervalGraphType {
33    fn from(value: RawIntervalGraphType) -> Self {
34        value.0.map(
35            |_, &node| unsafe { Instance::from_entity_unchecked(node) },
36            |_, &edge| unsafe { Instance::from_entity_unchecked(edge) },
37        )
38    }
39}
40
41impl From<IntervalGraphType> for RawIntervalGraphType {
42    fn from(value: IntervalGraphType) -> Self {
43        Self(value.map(|_, node| node.entity(), |_, edge| edge.entity()))
44    }
45}
46
47impl Serialize for Graph {
48    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
49    where
50        S: Serializer,
51    {
52        let raw: RawIntervalGraphType = self.inner.clone().into();
53        raw.serialize(serializer)
54    }
55}
56
57impl<'de> Deserialize<'de> for Graph {
58    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
59    where
60        D: Deserializer<'de>,
61    {
62        let raw = RawIntervalGraphType::deserialize(deserializer)?;
63        let inner: IntervalGraphType = raw.into();
64        Ok(Graph {
65            inner,
66            indices: EntityHashMap::default(),
67        })
68    }
69}
70
71pub struct EdgeReference {
72    pub weight: Instance<Interval>,
73    pub source: Instance<Station>,
74    pub target: Instance<Station>,
75}
76
77impl MapEntities for Graph {
78    fn map_entities<M: EntityMapper>(&mut self, entity_mapper: &mut M) {
79        for node in self.inner.node_weights_mut() {
80            node.map_entities(entity_mapper);
81        }
82        for edge in self.inner.edge_weights_mut() {
83            edge.map_entities(entity_mapper);
84        }
85        let mut new_indices = EntityHashMap::default();
86        // construct the indices from the graph instead.
87        for index in self.inner.node_indices() {
88            let station = &self.inner[index];
89            new_indices.insert(station.entity(), index);
90        }
91        self.indices = new_indices;
92    }
93}
94
95impl Graph {
96    pub fn inner(&self) -> &IntervalGraphType {
97        &self.inner
98    }
99    pub fn clear(&mut self) {
100        self.inner.clear();
101        self.indices.clear();
102    }
103    pub fn edge_weight(
104        &self,
105        a: Instance<Station>,
106        b: Instance<Station>,
107    ) -> Option<&Instance<Interval>> {
108        let &a_index = self.indices.get(&a.entity())?;
109        let &b_index = self.indices.get(&b.entity())?;
110        self.inner
111            .edge_weight(self.inner.find_edge(a_index, b_index)?)
112    }
113    pub fn contains_edge(&self, a: Instance<Station>, b: Instance<Station>) -> bool {
114        let Some(&a_index) = self.indices.get(&a.entity()) else {
115            return false;
116        };
117        let Some(&b_index) = self.indices.get(&b.entity()) else {
118            return false;
119        };
120        self.inner.find_edge(a_index, b_index).is_some()
121    }
122    pub fn contains_node(&self, a: Instance<Station>) -> bool {
123        self.indices.contains_key(&a.entity())
124    }
125    pub fn node_index(&self, a: Instance<Station>) -> Option<NodeIndex> {
126        self.indices.get(&a.entity()).cloned()
127    }
128    pub fn entity(&self, index: NodeIndex) -> Option<Instance<Station>> {
129        self.inner.node_weight(index).cloned()
130    }
131    pub fn add_edge(
132        &mut self,
133        a: Instance<Station>,
134        b: Instance<Station>,
135        edge: Instance<Interval>,
136    ) {
137        let a_index = if let Some(&index) = self.indices.get(&a.entity()) {
138            index
139        } else {
140            let index = self.inner.add_node(a);
141            self.indices.insert(a.entity(), index);
142            index
143        };
144        let b_index = if let Some(&index) = self.indices.get(&b.entity()) {
145            index
146        } else {
147            let index = self.inner.add_node(b);
148            self.indices.insert(b.entity(), index);
149            index
150        };
151        self.inner.add_edge(a_index, b_index, edge);
152    }
153    pub fn add_node(&mut self, a: Instance<Station>) {
154        if self.indices.contains_key(&a.entity()) {
155            return;
156        }
157        let index = self.inner.add_node(a);
158        self.indices.insert(a.entity(), index);
159    }
160    pub fn edges_connecting(
161        &self,
162        a: Instance<Station>,
163        b: Instance<Station>,
164    ) -> impl Iterator<Item = EdgeReference> {
165        let a_idx = match self.indices.get(&a.entity()) {
166            None => return Either::Left(std::iter::empty()),
167            Some(i) => i.clone(),
168        };
169        let b_idx = match self.indices.get(&b.entity()) {
170            None => return Either::Left(std::iter::empty()),
171            Some(i) => i.clone(),
172        };
173        let edge = self
174            .inner
175            .edges_connecting(a_idx, b_idx)
176            .map(|e| EdgeReference {
177                weight: *e.weight(),
178                source: self.inner[e.source()],
179                target: self.inner[e.target()],
180            });
181        Either::Right(edge)
182    }
183}
184
185/// A station or in the transportation network
186#[derive(Component, Default, Deref, DerefMut, Debug, Clone, Reflect, Serialize, Deserialize)]
187#[reflect(Component, opaque, Serialize, Deserialize)]
188#[require(Name, StationEntries, Save)]
189pub struct Station(pub egui::Pos2);
190
191#[derive(Reflect, Component, Debug, Default, MapEntities)]
192#[reflect(Component, MapEntities)]
193#[relationship_target(relationship = TimetableEntry)]
194pub struct StationEntries(Vec<Entity>);
195
196impl StationEntries {
197    pub fn entries(&self) -> &[Entity] {
198        &self.0
199    }
200    /// WARNING: this method does not automatically clear vehicle entities. Clear before calling
201    /// This is for chaining
202    pub fn passing_vehicles<'a, F>(&self, buffer: &mut Vec<Entity>, mut get_parent: F)
203    where
204        F: FnMut(Entity) -> Option<&'a ChildOf>,
205    {
206        for entity in self.0.iter().cloned() {
207            let Some(vehicle) = get_parent(entity) else {
208                continue;
209            };
210            buffer.push(vehicle.0)
211        }
212    }
213}
214
215/// A track segment between two stations or nodes
216#[derive(Reflect, Component)]
217#[reflect(Component)]
218#[require(Name, Save)]
219pub struct Interval {
220    /// The length of the track segment
221    pub length: crate::units::distance::Distance,
222    /// The speed limit on the track segment, if any
223    pub speed_limit: Option<Velocity>,
224}
225
226pub struct GraphPlugin;
227
228impl Plugin for GraphPlugin {
229    fn build(&self, app: &mut App) {
230        app.init_resource::<Graph>().add_systems(
231            Update,
232            arrange::apply_graph_layout.run_if(resource_exists::<GraphLayoutTask>),
233        );
234    }
235}