// paiagram/intervals.rs

1use crate::{
2    units::speed::Velocity,
3    vehicles::{
4        AdjustTimetableEntry, AdjustVehicle, TimetableAdjustment, VehicleAdjustment,
5        entries::{ActualRouteEntry, TimetableEntry, VehicleScheduleCache},
6    },
7};
8use bevy::{
9    ecs::{entity::EntityHashMap, resource},
10    prelude::*,
11};
12use egui_graphs::to_graph;
13use petgraph::{self, Directed, csr::DefaultIx, graph::NodeIndex, prelude::StableGraph};
14
15pub type IntervalGraphType = StableGraph<Entity, Entity, Directed>;
16
17/// A graph representing the transportation network
18#[derive(Resource)]
19pub struct Graph {
20    pub inner: IntervalGraphType,
21    pub indexes: EntityHashMap<NodeIndex>,
22}
23
24impl Default for Graph {
25    fn default() -> Self {
26        Self {
27            inner: IntervalGraphType::new(),
28            indexes: EntityHashMap::default(),
29        }
30    }
31}
32
33#[derive(Resource, Deref, DerefMut)]
34pub struct UiGraph(
35    pub  egui_graphs::Graph<
36        Entity,
37        Entity,
38        Directed,
39        DefaultIx,
40        egui_graphs::DefaultNodeShape,
41        egui_graphs::DefaultEdgeShape,
42    >,
43);
44
45impl FromWorld for UiGraph {
46    fn from_world(world: &mut World) -> Self {
47        Self(to_graph(&world.resource::<Graph>().inner))
48    }
49}
50
51impl Graph {
52    pub fn edge_weight(&self, from: Entity, to: Entity) -> Option<&Entity> {
53        let from_idx = *self.indexes.get(&from)?;
54        let to_idx = *self.indexes.get(&to)?;
55        self.inner
56            .find_edge(from_idx, to_idx)
57            .and_then(|edge_idx| self.inner.edge_weight(edge_idx))
58    }
59    pub fn node(&self, entity: Entity) -> Option<NodeIndex> {
60        self.indexes.get(&entity).copied()
61    }
62    pub fn contains_node(&self, node: Entity) -> bool {
63        self.indexes.contains_key(&node)
64    }
65    pub fn contains_edge(&self, from: Entity, to: Entity) -> bool {
66        let from_idx = match self.indexes.get(&from) {
67            Some(idx) => *idx,
68            None => return false,
69        };
70        let to_idx = match self.indexes.get(&to) {
71            Some(idx) => *idx,
72            None => return false,
73        };
74        self.inner.find_edge(from_idx, to_idx).is_some()
75    }
76    pub fn entity(&self, node: NodeIndex) -> Option<Entity> {
77        for (entity, &idx) in self.indexes.iter() {
78            if idx == node {
79                return Some(*entity);
80            }
81        }
82        None
83    }
84}
85
86#[derive(Message)]
87pub enum GraphAdjustment {
88    AddEdge(GraphAdjustmentEdgeAddition),
89    RemoveEdge(Entity),
90    AddNode(Entity),
91}
92
93pub struct GraphAdjustmentEdgeAddition {
94    from: Entity,
95    to: Entity,
96    weight: Entity,
97}
98
99/// A station or node in the transportation network
100#[derive(Component)]
101#[require(Name, StationCache)]
102pub struct Station;
103
104#[derive(Component, Debug, Default)]
105pub struct StationCache {
106    pub passing_entries: Vec<Entity>,
107}
108
109impl StationCache {
110    /// WARNING: this method does not automatically clear vehicle entities. Clear before calling
111    /// This is for chaining
112    pub fn passing_vehicles<'a, F>(&self, buffer: &mut Vec<Entity>, mut get_parent: F)
113    where
114        F: FnMut(Entity) -> Option<&'a ChildOf>,
115    {
116        for entity in self.passing_entries.iter().cloned() {
117            let Some(vehicle) = get_parent(entity) else {
118                continue;
119            };
120            buffer.push(vehicle.0)
121        }
122    }
123}
124
125/// A depot or yard in the transportation network
126/// A depot cannot be a node in the transportation network graph. Use `Station` for that.
127#[derive(Component)]
128#[require(Name)]
129pub struct Depot;
130
131/// A track segment between two stations or nodes
132#[derive(Component)]
133#[require(Name, IntervalCache)]
134pub struct Interval {
135    /// The length of the track segment
136    pub length: crate::units::distance::Distance,
137    /// The speed limit on the track segment, if any
138    pub speed_limit: Option<Velocity>,
139}
140
141#[derive(Component, Debug, Default)]
142pub struct IntervalCache {
143    // start of the interval.
144    pub passing_entries: Vec<ActualRouteEntry>,
145}
146
147impl IntervalCache {
148    /// WARNING: this method does not automatically clear vehicle entities. Clear before calling
149    /// This is for chaining
150    pub fn passing_vehicles<'a, F>(&self, buffer: &mut Vec<Entity>, mut get_parent: F)
151    where
152        F: FnMut(Entity) -> Option<&'a ChildOf>,
153    {
154        for entity in self.passing_entries.iter().cloned() {
155            let Some(vehicle) = get_parent(entity.inner()) else {
156                continue;
157            };
158            buffer.push(vehicle.0)
159        }
160    }
161}
162
163pub struct IntervalsPlugin;
164impl Plugin for IntervalsPlugin {
165    fn build(&self, app: &mut App) {
166        app.insert_resource(Graph::default())
167            .init_resource::<IntervalsResource>()
168            .init_resource::<UiGraph>()
169            .add_systems(
170                FixedPostUpdate,
171                (
172                    update_station_cache.run_if(on_message::<AdjustTimetableEntry>),
173                    update_interval_cache,
174                    update_ui_graph,
175                ),
176            );
177    }
178}
179
180#[derive(Resource)]
181pub struct IntervalsResource {
182    pub default_depot: Entity,
183}
184
185impl FromWorld for IntervalsResource {
186    fn from_world(world: &mut World) -> Self {
187        // create a depot once and stash the entity so callers can rely on it existing
188        let default_depot = world.spawn((Name::new("Default Depot"), Depot)).id();
189        Self { default_depot }
190    }
191}
192
193fn update_ui_graph(
194    mut graph: Res<Graph>,
195    mut ui_graph: ResMut<UiGraph>,
196    station_names: Query<&Name, With<Station>>,
197    intervals: Query<Ref<Interval>>,
198) {
199    if !(graph.is_changed() || intervals.iter().any(|i| i.is_changed())) {
200        return;
201    }
202    let node_transform = |n: &mut egui_graphs::Node<Entity, Entity>| {
203        // show the name of the station if available
204        if let Ok(name) = station_names.get(n.props().payload) {
205            n.set_label(name.as_str().to_string());
206        }
207    };
208    let edge_transform = |e: &mut egui_graphs::Edge<Entity, Entity>| {};
209    ui_graph.0 = egui_graphs::to_graph_custom(&graph.inner, node_transform, edge_transform);
210}
211
212fn update_station_cache(
213    mut msg_entry_change: MessageReader<AdjustTimetableEntry>,
214    mut msg_schedule_change: MessageReader<AdjustVehicle>,
215    timetable_entries: Query<&TimetableEntry>,
216    mut station_caches: Query<&mut StationCache>,
217) {
218    for msg in msg_entry_change.read() {
219        let Ok(entry) = timetable_entries.get(msg.entity) else {
220            continue;
221        };
222        let Ok(mut current_station_cache) = station_caches.get_mut(entry.station) else {
223            continue;
224        };
225        let index = current_station_cache
226            .passing_entries
227            .binary_search(&msg.entity);
228        match (&msg.adjustment, index) {
229            (&TimetableAdjustment::SetStation(_new_station), Ok(index)) => {
230                current_station_cache.passing_entries.remove(index);
231            }
232            (_, Err(index)) => {
233                current_station_cache
234                    .passing_entries
235                    .insert(index, msg.entity);
236            }
237            _ => {}
238        }
239    }
240    for entity in msg_schedule_change.read().filter_map(|msg| {
241        if let VehicleAdjustment::RemoveEntry(entity) = msg.adjustment {
242            Some(entity)
243        } else {
244            None
245        }
246    }) {
247        if let Ok(entry) = timetable_entries.get(entity)
248            && let Ok(mut station_cache) = station_caches.get_mut(entry.station)
249            && let Ok(index) = station_cache.passing_entries.binary_search(&entity)
250        {
251            station_cache.passing_entries.remove(index);
252        };
253    }
254}
255
256pub fn update_interval_cache(
257    changed_schedules: Populated<&VehicleScheduleCache, Changed<VehicleScheduleCache>>,
258    mut intervals: Query<&mut IntervalCache>,
259    timetable_entries: Query<&TimetableEntry>,
260    graph: Res<Graph>,
261    mut invalidated: Local<Vec<Entity>>,
262) {
263    invalidated.clear();
264    for schedule in changed_schedules {
265        let Some(actual_route) = &schedule.actual_route else {
266            continue;
267        };
268        for w in actual_route.windows(2) {
269            let [beg, end] = w else { continue };
270            let Ok(beg_entry) = timetable_entries.get(beg.inner()) else {
271                continue;
272            };
273            let Ok(end_entry) = timetable_entries.get(end.inner()) else {
274                continue;
275            };
276            let Some(&edge) = graph.edge_weight(beg_entry.station, end_entry.station) else {
277                continue;
278            };
279            let Ok(mut cache) = intervals.get_mut(edge) else {
280                continue;
281            };
282            // now that we have the cache, invalidate the cache first
283            if !invalidated.contains(&edge) {
284                cache
285                    .passing_entries
286                    .retain(|e| matches!(e, ActualRouteEntry::Nominal(_)));
287                invalidated.push(edge)
288            }
289            cache.passing_entries.push(*end);
290        }
291    }
292    for invalidated in invalidated.iter().copied() {
293        let Ok(mut cache) = intervals.get_mut(invalidated) else {
294            continue;
295        };
296        cache.passing_entries.sort_unstable_by_key(|e| e.inner());
297        cache.passing_entries.dedup_by_key(|e| e.inner());
298    }
299}