Aura Graph Analytics with Spark
This Jupyter notebook is hosted here in the Neo4j Graph Data Science Client GitHub repository.
The notebook shows how to use the graphdatascience Python library to
create, manage, and use a GDS Session from within an Apache Spark
cluster.
We consider a graph of bicycle rentals, which we’re using as a simple example to show how to project data from Spark to a GDS Session, run algorithms, and eventually return results back to Spark. In this notebook we will focus on the interaction with Apache Spark, and will not cover all possible actions using GDS sessions. We refer to other Tutorials for additional details.
1. Prerequisites
We need to have the graphdatascience Python library installed, version 1.18 or later, as well as pyspark.
%pip install "graphdatascience>=1.18" python-dotenv "pyspark[sql]"
from dotenv import load_dotenv
# This allows us to load required secrets from a `.env` file in the local directory.
# These can include Aura API credentials and database credentials.
# If the file does not exist, this is a no-op.
load_dotenv("sessions.env")
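For reference, a sessions.env file for this notebook could look like the sketch below. The variable names match the os.environ lookups used later in this notebook; the values are placeholders to replace with your own Aura API credentials.
# Example contents of a hypothetical sessions.env file (placeholder values):
# CLIENT_ID=<your Aura API client id>
# CLIENT_SECRET=<your Aura API client secret>
# PROJECT_ID=<your Aura project id>  # optional, only needed if your account has several projects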
1.1. Connecting to a Spark Session
To interact with the Spark cluster we first need to instantiate a Spark session. In this example we will use a local Spark session, which runs Spark on the same machine. Working with a remote Spark cluster is similar. For more information about setting up pyspark, visit https://spark.apache.org/docs/latest/api/python/getting_started/
from pyspark.sql import SparkSession
spark = SparkSession.builder.master("local[4]").appName("GraphAnalytics").getOrCreate()
# Enable Arrow-based columnar data transfers
spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")
2. Aura API credentials
The entry point for managing GDS Sessions is the GdsSessions object,
which requires creating
Aura API credentials.
import os
from graphdatascience.session import AuraAPICredentials, GdsSessions
# you can also use AuraAPICredentials.from_env() to load credentials from environment variables
api_credentials = AuraAPICredentials(
    client_id=os.environ["CLIENT_ID"],
    client_secret=os.environ["CLIENT_SECRET"],
    # If your account is a member of several projects, you must also specify the project ID to use
    project_id=os.environ.get("PROJECT_ID", None),
)
sessions = GdsSessions(api_credentials=api_credentials)
3. Creating a new session
A new session is created by calling sessions.get_or_create()
with the following parameters:
- A session name, which lets you reconnect to an existing session by calling get_or_create again.
- The session memory.
- The cloud location.
- A time-to-live (TTL), which ensures that the session is automatically deleted after being unused for the set time, to avoid incurring costs.
See the API reference documentation or the manual for more details on the parameters.
from datetime import timedelta
from graphdatascience.session import CloudLocation, SessionMemory
# Create a GDS session!
gds = sessions.get_or_create(
    # we give it a representative name
    session_name="bike_trips",
    memory=SessionMemory.m_2GB,
    ttl=timedelta(minutes=30),
    cloud_location=CloudLocation("gcp", "europe-west1"),
)
# Verify the connectivity. A failure directly after get_or_create hints at TLS or firewall issues.
gds.verify_connectivity()
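To confirm what was created, and to see any other sessions tied to these Aura API credentials, we can list the sessions. This is a minimal sketch that assumes the sessions.list() call of the GdsSessions API.
# List all sessions visible to these Aura API credentials (sketch; assumes GdsSessions.list())
for session_info in sessions.list():
    print(session_info)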
4. Adding a dataset
As the next step we will set up a dataset in Spark. In this example we will use the New York Bike trips dataset (https://www.kaggle.com/datasets/gabrielramos87/bike-trips). The bike trips form a graph where nodes represent bike rental stations and relationships represent bike rental trips from a start station to an end station.
import io
import os
import zipfile
import requests
download_path = "bike_trips_data"
if not os.path.exists(download_path):
    url = "https://www.kaggle.com/api/v1/datasets/download/gabrielramos87/bike-trips"
    response = requests.get(url)
    response.raise_for_status()
    # Unzip the content
    with zipfile.ZipFile(io.BytesIO(response.content)) as z:
        z.extractall(download_path)
df = spark.read.csv(download_path, header=True, inferSchema=True)
df.createOrReplaceTempView("bike_trips")
df.limit(10).show()
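Before projecting, it helps to get a feel for the size of the graph we are about to build. The query below is a small sketch that counts the trips and the distinct stations, using the same start_station_id and end_station_id columns that we project later.
# Sketch: basic statistics about the trips table (same columns as used in the projection below)
spark.sql("""
    SELECT
        COUNT(*) AS trips,
        COUNT(DISTINCT start_station_id) AS start_stations,
        COUNT(DISTINCT end_station_id) AS end_stations
    FROM bike_trips
""").show()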
5. Projecting Graphs
Now that we have our dataset available within our Spark session, it is time to project it into the GDS Session.
We first need to get access to the GDSArrowClient. This client allows us to directly communicate with the Arrow Flight server provided by the session.
Our input data already resembles triplets, where each row represents an edge from a source station to a target station. This allows us to use the Arrow server’s "graph import from triplets" functionality, which requires the following protocol:
- Send an action v2/graph.project.fromTriplets. This will initialize the import process and allows us to specify the graph name, and settings like undirected_relationship_types. It returns a job ID, which we need to reference the import job in the following steps.
- Send the data in batches to the Arrow server.
- Send another action called v2/graph.project.fromTriplets.done to tell the import process that no more data will be sent. This will trigger the final graph creation inside the GDS session.
- Wait for the import process to reach the DONE state.
The most complicated step here is running the actual data upload on each Spark worker. We will use the mapInArrow function to run custom code on each worker. Each worker receives a number of Arrow record batches that we can send directly to the GDS session’s Arrow server.
Note that we import the time module and add a 1-second delay (time.sleep(1)) inside the loop at the end of the cell that waits for the import job to finish, so that we do not busy-poll the job status.
import time
import pandas as pd
import pyarrow
from pyspark.sql import functions
graph_name = "bike_trips"
arrow_client = gds.arrow_client()
# 1. Start the import process
job_id = arrow_client.create_graph_from_triplets(graph_name, concurrency=4)
# Define a function that receives an arrow batch and uploads it to the GDS session
def upload_batch(iterator):
    for batch in iterator:
        arrow_client.upload_triplets(job_id, [batch])
        yield pyarrow.RecordBatch.from_pandas(pd.DataFrame({"batch_rows_imported": [len(batch)]}))
# Select the source target pairs from our source data
source_target_pairs = spark.sql("""
    SELECT start_station_id AS sourceNode, end_station_id AS targetNode
    FROM bike_trips
""")
# 2. Use the `mapInArrow` function to upload the data to the GDS session. Returns a DataFrame with a single column containing the batch sizes.
uploaded_batches = source_target_pairs.mapInArrow(upload_batch, "batch_rows_imported long")
# Aggregate the batch sizes to receive the row count.
aggregated_batch_sizes = uploaded_batches.agg(functions.sum("batch_rows_imported").alias("rows_imported"))
# Show the result. This will trigger the computation and thus run the data upload.
aggregated_batch_sizes.show()
# 3. Finish the import process
arrow_client.triplet_load_done(job_id)
# 4. Wait for the import to finish
while not arrow_client.job_status(job_id).succeeded():
    time.sleep(1)
G = gds.v2.graph.get(graph_name)
G
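As a quick sanity check, we can look at the size of the projected graph. The sketch below assumes the returned Graph object exposes node_count() and relationship_count(), as in the standard GDS Python Client API.
# Sketch: inspect the projected graph (assumes node_count()/relationship_count() on the Graph object)
print(f"Nodes: {G.node_count()}, relationships: {G.relationship_count()}")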
6. Running Algorithms
We can run algorithms on the constructed graph using the standard GDS Python Client API. See the other tutorials for more examples.
print("Running PageRank ...")
pr_result = gds.v2.page_rank.mutate(G, mutate_property="pagerank")
7. Sending the computation result back to Spark
Once the computation is done, we might want to further use the result in Spark. We can do this in a similar way to the projection, by streaming batches of data into each of the Spark workers. Retrieving the data is a bit more involved, since we need some input DataFrame to trigger computations on the Spark workers. We use a range with one row per worker in our cluster as the driving table. On the workers we disregard the input and instead stream the computation results from the GDS Session.
# 1. Start the node property export on the GDS session
job_id = arrow_client.get_node_properties(G.name(), ["pagerank"])
# Define a function that receives data from the GDS Session and turns it into data batches
def retrieve_data(ignored):
    stream_data = arrow_client.stream_job(G.name(), job_id)
    batches = pyarrow.Table.from_pandas(stream_data).to_batches(1000)
    for b in batches:
        yield b
# Create DataFrame with a single column and one row per worker
input_partitions = spark.range(spark.sparkContext.defaultParallelism).toDF("batch_id")
# 2. Stream the data from the GDS Session into the Spark workers
received_batches = input_partitions.mapInArrow(retrieve_data, "nodeId long, pagerank double")
# Optional: Repartition the data to make sure it is distributed equally
result = received_batches.repartition(numPartitions=spark.sparkContext.defaultParallelism)
result.toPandas()
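Once we are done, we can delete the session so it does not keep incurring costs, and stop the local Spark session. This is a minimal sketch that assumes the sessions.delete() call of the GdsSessions API accepts the session name used above.
# Sketch: clean up resources (assumes GdsSessions.delete(session_name=...))
sessions.delete(session_name="bike_trips")
# Stop the local Spark session
spark.stop()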