1use crate::{
2 units::speed::Velocity,
3 vehicles::{
4 AdjustTimetableEntry, AdjustVehicle, TimetableAdjustment, VehicleAdjustment,
5 entries::{ActualRouteEntry, TimetableEntry, VehicleScheduleCache},
6 },
7};
8use bevy::{
9 ecs::{entity::EntityHashMap, resource},
10 prelude::*,
11};
12use egui_graphs::to_graph;
13use petgraph::{self, Directed, csr::DefaultIx, graph::NodeIndex, prelude::StableGraph};
14
15pub type IntervalGraphType = StableGraph<Entity, Entity, Directed>;
16
17#[derive(Resource)]
19pub struct Graph {
20 pub inner: IntervalGraphType,
21 pub indexes: EntityHashMap<NodeIndex>,
22}
23
24impl Default for Graph {
25 fn default() -> Self {
26 Self {
27 inner: IntervalGraphType::new(),
28 indexes: EntityHashMap::default(),
29 }
30 }
31}
32
33#[derive(Resource, Deref, DerefMut)]
34pub struct UiGraph(
35 pub egui_graphs::Graph<
36 Entity,
37 Entity,
38 Directed,
39 DefaultIx,
40 egui_graphs::DefaultNodeShape,
41 egui_graphs::DefaultEdgeShape,
42 >,
43);
44
45impl FromWorld for UiGraph {
46 fn from_world(world: &mut World) -> Self {
47 Self(to_graph(&world.resource::<Graph>().inner))
48 }
49}
50
51impl Graph {
52 pub fn edge_weight(&self, from: Entity, to: Entity) -> Option<&Entity> {
53 let from_idx = *self.indexes.get(&from)?;
54 let to_idx = *self.indexes.get(&to)?;
55 self.inner
56 .find_edge(from_idx, to_idx)
57 .and_then(|edge_idx| self.inner.edge_weight(edge_idx))
58 }
59 pub fn node(&self, entity: Entity) -> Option<NodeIndex> {
60 self.indexes.get(&entity).copied()
61 }
62 pub fn contains_node(&self, node: Entity) -> bool {
63 self.indexes.contains_key(&node)
64 }
65 pub fn contains_edge(&self, from: Entity, to: Entity) -> bool {
66 let from_idx = match self.indexes.get(&from) {
67 Some(idx) => *idx,
68 None => return false,
69 };
70 let to_idx = match self.indexes.get(&to) {
71 Some(idx) => *idx,
72 None => return false,
73 };
74 self.inner.find_edge(from_idx, to_idx).is_some()
75 }
76 pub fn entity(&self, node: NodeIndex) -> Option<Entity> {
77 for (entity, &idx) in self.indexes.iter() {
78 if idx == node {
79 return Some(*entity);
80 }
81 }
82 None
83 }
84}
85
86#[derive(Message)]
87pub enum GraphAdjustment {
88 AddEdge(GraphAdjustmentEdgeAddition),
89 RemoveEdge(Entity),
90 AddNode(Entity),
91}
92
93pub struct GraphAdjustmentEdgeAddition {
94 from: Entity,
95 to: Entity,
96 weight: Entity,
97}
98
99#[derive(Component)]
101#[require(Name, StationCache)]
102pub struct Station;
103
104#[derive(Component, Debug, Default)]
105pub struct StationCache {
106 pub passing_entries: Vec<Entity>,
107}
108
109impl StationCache {
110 pub fn passing_vehicles<'a, F>(&self, buffer: &mut Vec<Entity>, mut get_parent: F)
113 where
114 F: FnMut(Entity) -> Option<&'a ChildOf>,
115 {
116 for entity in self.passing_entries.iter().cloned() {
117 let Some(vehicle) = get_parent(entity) else {
118 continue;
119 };
120 buffer.push(vehicle.0)
121 }
122 }
123}
124
125#[derive(Component)]
128#[require(Name)]
129pub struct Depot;
130
131#[derive(Component)]
133#[require(Name, IntervalCache)]
134pub struct Interval {
135 pub length: crate::units::distance::Distance,
137 pub speed_limit: Option<Velocity>,
139}
140
141#[derive(Component, Debug, Default)]
142pub struct IntervalCache {
143 pub passing_entries: Vec<ActualRouteEntry>,
145}
146
147impl IntervalCache {
148 pub fn passing_vehicles<'a, F>(&self, buffer: &mut Vec<Entity>, mut get_parent: F)
151 where
152 F: FnMut(Entity) -> Option<&'a ChildOf>,
153 {
154 for entity in self.passing_entries.iter().cloned() {
155 let Some(vehicle) = get_parent(entity.inner()) else {
156 continue;
157 };
158 buffer.push(vehicle.0)
159 }
160 }
161}
162
163pub struct IntervalsPlugin;
164impl Plugin for IntervalsPlugin {
165 fn build(&self, app: &mut App) {
166 app.insert_resource(Graph::default())
167 .init_resource::<IntervalsResource>()
168 .init_resource::<UiGraph>()
169 .add_systems(
170 FixedPostUpdate,
171 (
172 update_station_cache.run_if(on_message::<AdjustTimetableEntry>),
173 update_interval_cache,
174 update_ui_graph,
175 ),
176 );
177 }
178}
179
180#[derive(Resource)]
181pub struct IntervalsResource {
182 pub default_depot: Entity,
183}
184
185impl FromWorld for IntervalsResource {
186 fn from_world(world: &mut World) -> Self {
187 let default_depot = world.spawn((Name::new("Default Depot"), Depot)).id();
189 Self { default_depot }
190 }
191}
192
193fn update_ui_graph(
194 mut graph: Res<Graph>,
195 mut ui_graph: ResMut<UiGraph>,
196 station_names: Query<&Name, With<Station>>,
197 intervals: Query<Ref<Interval>>,
198) {
199 if !(graph.is_changed() || intervals.iter().any(|i| i.is_changed())) {
200 return;
201 }
202 let node_transform = |n: &mut egui_graphs::Node<Entity, Entity>| {
203 if let Ok(name) = station_names.get(n.props().payload) {
205 n.set_label(name.as_str().to_string());
206 }
207 };
208 let edge_transform = |e: &mut egui_graphs::Edge<Entity, Entity>| {};
209 ui_graph.0 = egui_graphs::to_graph_custom(&graph.inner, node_transform, edge_transform);
210}
211
212fn update_station_cache(
213 mut msg_entry_change: MessageReader<AdjustTimetableEntry>,
214 mut msg_schedule_change: MessageReader<AdjustVehicle>,
215 timetable_entries: Query<&TimetableEntry>,
216 mut station_caches: Query<&mut StationCache>,
217) {
218 for msg in msg_entry_change.read() {
219 let Ok(entry) = timetable_entries.get(msg.entity) else {
220 continue;
221 };
222 let Ok(mut current_station_cache) = station_caches.get_mut(entry.station) else {
223 continue;
224 };
225 let index = current_station_cache
226 .passing_entries
227 .binary_search(&msg.entity);
228 match (&msg.adjustment, index) {
229 (&TimetableAdjustment::SetStation(_new_station), Ok(index)) => {
230 current_station_cache.passing_entries.remove(index);
231 }
232 (_, Err(index)) => {
233 current_station_cache
234 .passing_entries
235 .insert(index, msg.entity);
236 }
237 _ => {}
238 }
239 }
240 for entity in msg_schedule_change.read().filter_map(|msg| {
241 if let VehicleAdjustment::RemoveEntry(entity) = msg.adjustment {
242 Some(entity)
243 } else {
244 None
245 }
246 }) {
247 if let Ok(entry) = timetable_entries.get(entity)
248 && let Ok(mut station_cache) = station_caches.get_mut(entry.station)
249 && let Ok(index) = station_cache.passing_entries.binary_search(&entity)
250 {
251 station_cache.passing_entries.remove(index);
252 };
253 }
254}
255
256pub fn update_interval_cache(
257 changed_schedules: Populated<&VehicleScheduleCache, Changed<VehicleScheduleCache>>,
258 mut intervals: Query<&mut IntervalCache>,
259 timetable_entries: Query<&TimetableEntry>,
260 graph: Res<Graph>,
261 mut invalidated: Local<Vec<Entity>>,
262) {
263 invalidated.clear();
264 for schedule in changed_schedules {
265 let Some(actual_route) = &schedule.actual_route else {
266 continue;
267 };
268 for w in actual_route.windows(2) {
269 let [beg, end] = w else { continue };
270 let Ok(beg_entry) = timetable_entries.get(beg.inner()) else {
271 continue;
272 };
273 let Ok(end_entry) = timetable_entries.get(end.inner()) else {
274 continue;
275 };
276 let Some(&edge) = graph.edge_weight(beg_entry.station, end_entry.station) else {
277 continue;
278 };
279 let Ok(mut cache) = intervals.get_mut(edge) else {
280 continue;
281 };
282 if !invalidated.contains(&edge) {
284 cache
285 .passing_entries
286 .retain(|e| matches!(e, ActualRouteEntry::Nominal(_)));
287 invalidated.push(edge)
288 }
289 cache.passing_entries.push(*end);
290 }
291 }
292 for invalidated in invalidated.iter().copied() {
293 let Ok(mut cache) = intervals.get_mut(invalidated) else {
294 continue;
295 };
296 cache.passing_entries.sort_unstable_by_key(|e| e.inner());
297 cache.passing_entries.dedup_by_key(|e| e.inner());
298 }
299}