Modeling Disruptions in the MTA Subway
Sr. Manager, Technical Product Marketing, Neo4j

Roughly 3.6 million people ride the New York City subway each day. That’s 1.3 million more than take flights every day in the United States. The New York City Metropolitan Transportation Authority (MTA) subway system is unique in the number of riders, stations, and the level of service it provides.
Using Aura Graph Analytics, we can easily model what would happen if a station were fully closed for repairs. Insights like this apply not just to transport systems but to supply chains, manufacturing processes, and much more.
For supply chains, imagine if a particular vendor was hit with tariffs or, even worse, went out of business. Which alternate supplier path gets your product to market fastest with the least disruption? Using the same shortest-path techniques demonstrated in this blog, you can quickly evaluate the next best route and keep operations running smoothly.
Aura Graph Analytics works with any enterprise data. In this example, we’re going to load data from Snowflake into dataframes, create a graph projection, and run algorithms – all without having to move our data into AuraDB!
Whether it’s passengers or products, understanding and adapting the paths through your network is key to resilience.
Getting Started
We’re going to work in a Google Colab notebook, but you can run this in any Python environment. First, we need to install and load the necessary packages:
!pip install graphdatascience
from graphdatascience.session import GdsSessions, AuraAPICredentials, DbmsConnectionInfo, AlgorithmCategory
from datetime import timedelta
import pandas as pd
import os
from google.colab import userdata
Create a Session
Next, we set up a session, first by loading in our secrets:
CLIENT_ID = userdata.get("CLIENT_ID")
CLIENT_SECRET = userdata.get("CLIENT_SECRET")
TENANT_ID = userdata.get("TENANT_ID")
Then by establishing a session:
from graphdatascience.session import GdsSessions, AuraAPICredentials, AlgorithmCategory, CloudLocation
from datetime import timedelta
sessions = GdsSessions(api_credentials=AuraAPICredentials(CLIENT_ID, CLIENT_SECRET, TENANT_ID))
name = "my-new-session-subway"
memory = sessions.estimate(
    node_count=20,
    relationship_count=50,
    algorithm_categories=[AlgorithmCategory.CENTRALITY, AlgorithmCategory.NODE_EMBEDDING],
)
cloud_location = CloudLocation(provider="gcp", region="europe-west1")
gds = sessions.get_or_create(
    session_name=name,
    memory=memory,
    ttl=timedelta(hours=5),
    cloud_location=cloud_location,
)
Load Data From Snowflake
Load this data into Snowflake or directly into your Python environment.
A key advantage of Aura Graph Analytics is that you don’t need to store your data in AuraDB to use it. In our case, we’ll load data from Snowflake into Python dataframes. Let’s start by downloading the snowflake-connector-python package:
!pip install snowflake-connector-python
Then we create a connection to Snowflake:
import pandas as pd
import snowflake.connector
SNOWFLAKE_USER = userdata.get("snowflake_user")
SNOWFLAKE_PASSWORD = userdata.get("snowflake_password")
SNOWFLAKE_ACCOUNT = userdata.get("snowflake_account")
# Replace with your credentials
conn = snowflake.connector.connect(
    user=SNOWFLAKE_USER,
    password=SNOWFLAKE_PASSWORD,
    account=SNOWFLAKE_ACCOUNT,
    warehouse='GDSONSNOWFLAKE',
    database='MTA',
    schema='PUBLIC',
)
And return our two tables as Python dataframes:
cur = conn.cursor()
cur.execute("SELECT * FROM LINES")
lines = cur.fetch_pandas_all()
cur.close()
lines
| STARTING_STATION | NEXT_STATION | RELATIONSHIPTYPE |
| 0 | 1 | GOES_TO |
| 1 | 2 | GOES_TO |
| 2 | 3 | GOES_TO |
| 3 | 4 | GOES_TO |
cur = conn.cursor()
cur.execute("SELECT * FROM stations")
stations = cur.fetch_pandas_all()
cur.close()
stations
| STATION_NAME | ID |
| Van Cortlandt Park-242 – Bx | 0 |
| 238 St – Bx | 1 |
| 231 St – Bx | 2 |
| Marble Hill-225 St – M | 3 |
| 215 St – M | 4 |
Creating a Projection
We need to do some mild cleanup to make sure everything has the right names.
For the dataframe representing nodes:
- The first column should be called nodeId.
- The remaining columns cannot contain text, so we have to drop the station names.
# Snowflake returns the column name in upper case ('ID'), as shown in the table above
stations = stations.rename(columns={'ID': 'nodeId'})
nodes = stations[['nodeId']]
nodes
For the dataframe representing relationships, we need to have columns called sourceNodeId and targetNodeId:
# Rename the edge columns to the names graph.construct expects
lines = lines.rename(
    columns={
        'STARTING_STATION': 'targetNodeId',
        'NEXT_STATION': 'sourceNodeId'
    }
)
# Keep only the two ID columns (this drops RELATIONSHIPTYPE)
lines = lines[['targetNodeId', 'sourceNodeId']]
lines
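Before building the projection, it can help to confirm that the two dataframes follow the rules described above. The sketch below is one possible check, not part of the library: `check_projection_frames` is a hypothetical helper, the toy frames are made up, and it assumes the stripped-down frames used in this article, where every column is a numeric ID.

```python
import pandas as pd
from pandas.api.types import is_numeric_dtype


def check_projection_frames(nodes: pd.DataFrame, rels: pd.DataFrame) -> list[str]:
    """Return a list of problems; an empty list means the frames look usable."""
    problems = []
    if "nodeId" not in nodes.columns:
        problems.append("nodes is missing a 'nodeId' column")
    for col in ("sourceNodeId", "targetNodeId"):
        if col not in rels.columns:
            problems.append(f"relationships is missing a '{col}' column")
    # Flag text columns such as station names.
    for name, frame in (("nodes", nodes), ("relationships", rels)):
        for col in frame.columns:
            if not is_numeric_dtype(frame[col]):
                problems.append(f"{name}.{col} is not numeric")
    # Only check references once the basic shape is right.
    if not problems:
        known = set(nodes["nodeId"])
        for col in ("sourceNodeId", "targetNodeId"):
            dangling = set(rels[col]) - known
            if dangling:
                problems.append(f"{col} references unknown node IDs: {sorted(dangling)}")
    return problems


# Made-up toy data, not the MTA tables.
toy_nodes = pd.DataFrame({"nodeId": [0, 1, 2]})
toy_rels = pd.DataFrame({"sourceNodeId": [0, 1], "targetNodeId": [1, 2]})
print(check_projection_frames(toy_nodes, toy_rels))  # []
```

Running the same check on the real `nodes` and `lines` frames before calling graph.construct would surface a missed rename or a leftover text column early.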
Graph Construct
Using graph.construct, we can easily create a projection:
graph_name = "subways"
if gds.graph.exists(graph_name)["exists"]:
    # Drop the graph if it exists
    gds.graph.drop(graph_name)
    print(f"Graph '{graph_name}' dropped.")
G = gds.graph.construct(graph_name, nodes, lines)
We’ll use Dijkstra’s shortest-path algorithm to see how we can move through the system efficiently. The simple wrapper function below lets us use station names rather than their nodeId values:
station_crosswalk = dict(zip(stations['STATION_NAME'], stations['nodeId']))
# Look up node IDs from station names and run Dijkstra between them
def get_shortest_path(source_station, target_station, G):
    # Map the station names to node IDs
    source_node_id = station_crosswalk.get(source_station)
    target_node_id = station_crosswalk.get(target_station)
    result = gds.shortestPath.dijkstra.stream(
        G,
        sourceNode=source_node_id,
        targetNode=target_node_id
    )
    # The first row holds the path as an ordered list of node IDs
    node_ids = result['nodeIds'][0]
    # Reverse the crosswalk so node IDs map back to station names
    id_to_station = {v: k for k, v in station_crosswalk.items()}
    ordered_subset = {id_to_station[i]: i for i in node_ids if i in id_to_station}
    return ordered_subset
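One thing to watch in the wrapper: `dict.get` returns `None` for a station name it does not recognize, so a small typo only shows up later as a confusing error from the algorithm call. A stricter lookup might look like the sketch below. `resolve_station` is a hypothetical helper, and `toy_crosswalk` is a hand-written stand-in for `station_crosswalk`.

```python
import difflib

# Hand-written stand-in for station_crosswalk (name -> node ID).
toy_crosswalk = {
    "Grand Army Plaza - Bk": 69,
    "Bergen St - Bk": 68,
    "Times Sq-42 St - M": 24,
}


def resolve_station(name: str, crosswalk: dict[str, int]) -> int:
    """Return the node ID for a station name, or raise with close matches."""
    try:
        return crosswalk[name]
    except KeyError:
        # Suggest up to three similarly spelled station names.
        suggestions = difflib.get_close_matches(name, list(crosswalk), n=3, cutoff=0.6)
        hint = f" Did you mean: {suggestions}?" if suggestions else ""
        raise LookupError(f"Unknown station {name!r}.{hint}") from None


print(resolve_station("Times Sq-42 St - M", toy_crosswalk))  # 24
```

Swapping the two `station_crosswalk.get(...)` calls for a lookup like this would fail fast on names that are missing the borough suffix.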
Let’s see how to get from Grand Army Plaza in Brooklyn to Times Square:
# Example usage
# Assuming 'G' is your graph
source_station = "Grand Army Plaza - Bk"
target_station = "Times Sq-42 St - M"
# Call the function
path = get_shortest_path(source_station, target_station, G)
path
This returns:
{'Grand Army Plaza - Bk': 69,
'Bergen St - Bk': 68,
'Atlantic Av-Barclays Ctr - Bk': 67,
'Canal St - M': 32,
'14 St-Union Sq - M': 104,
'34 St-Herald Sq - M': 230,
'Times Sq-42 St - M': 24}
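To make the algorithm less of a black box, here is a minimal pure-Python sketch of Dijkstra's algorithm on a made-up toy graph. It is not how the library implements it. The `dijkstra` function and the `toy` adjacency map are illustrative names, and every hop is given a cost of 1 on the assumption that the call above, which passes no weight property, treats all hops equally.

```python
import heapq


def dijkstra(adjacency, source, target):
    """Return the cheapest path from source to target as a list of nodes, or None."""
    best_cost = {source: 0.0}
    previous = {}
    queue = [(0.0, source)]
    while queue:
        cost, node = heapq.heappop(queue)
        if node == target:
            break
        if cost > best_cost.get(node, float("inf")):
            continue  # stale queue entry, a cheaper route was already found
        for neighbor, weight in adjacency.get(node, []):
            new_cost = cost + weight
            if new_cost < best_cost.get(neighbor, float("inf")):
                best_cost[neighbor] = new_cost
                previous[neighbor] = node
                heapq.heappush(queue, (new_cost, neighbor))
    if target not in best_cost:
        return None  # target is unreachable
    # Walk the predecessor links back from target to source.
    path = [target]
    while path[-1] != source:
        path.append(previous[path[-1]])
    return path[::-1]


# Toy directed graph: node -> [(neighbor, cost)]. Every hop costs 1.
toy = {
    0: [(1, 1), (2, 1)],
    1: [(3, 1)],
    2: [(4, 1)],
    4: [(3, 1)],
    3: [(5, 1)],
}
print(dijkstra(toy, 0, 5))  # [0, 1, 3, 5]
```

Replacing the unit costs with travel times would turn the same sketch into a fastest-route search rather than a fewest-stops search.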
Modeling Disruptions
But what if one of those stations closed? What would be the quickest path there? Let’s see what would happen if Herald Square was closed:
def exclude_node(nodes_df, lines_df, node_to_exclude):
    # Drop the closed station from the node list
    closed = nodes_df[nodes_df['nodeId'] != node_to_exclude]
    # Drop every connection that starts or ends at the closed station
    closed_lines = lines_df[
        (lines_df['sourceNodeId'] != node_to_exclude) &
        (lines_df['targetNodeId'] != node_to_exclude)
    ]
    return closed, closed_lines
# 230 is the node ID of 34 St-Herald Sq in the path above
closed_nodes, closed_lines = exclude_node(nodes, lines, 230)
We then need to create a new projection without Herald Square:
graph_name = "exclude"
if gds.graph.exists(graph_name)["exists"]:
    # Drop the graph if it exists
    gds.graph.drop(graph_name)
    print(f"Graph '{graph_name}' dropped.")
G = gds.graph.construct(graph_name, closed_nodes, closed_lines)
Then we rerun our algorithm:
# Example usage
# Assuming 'G' is your graph
source_station = "Grand Army Plaza - Bk"
target_station = "Times Sq-42 St - M"
# Call the function
path = get_shortest_path(source_station, target_station, G)
path
Which returns:
{'Grand Army Plaza - Bk': 69,
'Bergen St - Bk': 68,
'Atlantic Av-Barclays Ctr - Bk': 67,
'Canal St - M': 32,
'Chambers St - M': 34,
'14 St - M': 29,
'34 St-Penn Station - M': 25,
'Times Sq-42 St - M': 24}
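The detour adds one stop: the route now passes through eight stations instead of seven, going via Chambers St, 14 St, and 34 St-Penn Station rather than 14 St-Union Sq and 34 St-Herald Sq.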
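The same close-and-rerun idea extends to asking which closure along a route would hurt most. The sketch below does this on a made-up toy network in plain Python, using breadth-first search because every hop counts equally. `shortest_path`, `closure_impact`, and `toy_edges` are hypothetical names; in practice each closure would be a new projection and a new Dijkstra call, as shown above.

```python
from collections import deque


def shortest_path(edges, source, target, closed=frozenset()):
    """Breadth-first search over (source, target) pairs, skipping closed nodes."""
    adjacency = {}
    for a, b in edges:
        # Same filter as exclude_node: drop edges that touch a closed node.
        if a not in closed and b not in closed:
            adjacency.setdefault(a, []).append(b)
    previous = {source: None}
    queue = deque([source])
    while queue:
        node = queue.popleft()
        if node == target:
            path = []
            while node is not None:
                path.append(node)
                node = previous[node]
            return path[::-1]
        for neighbor in adjacency.get(node, []):
            if neighbor not in previous:
                previous[neighbor] = node
                queue.append(neighbor)
    return None


def closure_impact(edges, source, target):
    """Map each intermediate stop on the baseline path to the extra hops its closure adds."""
    baseline = shortest_path(edges, source, target)
    if baseline is None:
        return {}
    impact = {}
    for stop in baseline[1:-1]:
        detour = shortest_path(edges, source, target, closed={stop})
        # None means the closure disconnects source from target.
        impact[stop] = None if detour is None else len(detour) - len(baseline)
    return impact


# Toy network: A-B-C-D is the direct line, B-E-F-D is a loop around C.
toy_edges = [("A", "B"), ("B", "C"), ("C", "D"), ("B", "E"), ("E", "F"), ("F", "D")]
print(closure_impact(toy_edges, "A", "D"))  # {'B': None, 'C': 1}
```

In the toy network, closing C costs one extra hop, while closing B cuts the route entirely, which marks B as the stop to protect.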
Finally, we end our session:
sessions.delete(session_name="my-new-session-subway")
And with that, you can see how to run graph algorithms against any enterprise data, and how to model disruptions!
Summary and Next Steps
Now that you have a solid grasp of modeling disruptions (on the subway or otherwise), head over to our GitHub repo for step-by-step instructions on how to do it yourself with Neo4j Aura Graph Analytics. You’ll find a Colab notebook, the full dataset, and everything you need to get started.
Prefer working in Snowflake? You can run the same example there using Neo4j Graph Analytics for Snowflake.





