Aura Graph Analytics with Spark


This Jupyter notebook is hosted in the Neo4j Graph Data Science Client GitHub repository.

The notebook shows how to use the graphdatascience Python library to create, manage, and use a GDS Session from within an Apache Spark cluster.

We consider a graph of bicycle rentals as a simple example of how to project data from Spark to a GDS Session, run algorithms, and return the results to Spark. This notebook focuses on the interaction with Apache Spark and does not cover every action available in GDS Sessions. Refer to the other tutorials for additional details.

1. Prerequisites

We need the graphdatascience Python library, version 1.18 or later, and pyspark.

%pip install "graphdatascience>=1.18" python-dotenv "pyspark[sql]"
from dotenv import load_dotenv

# Load required secrets from the `sessions.env` file in the local directory.
# This can include Aura API credentials and database credentials.
# If the file does not exist, this is a no-op.
load_dotenv("sessions.env")
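The install command sets only a lower bound, so it can be useful to check which versions the kernel actually loaded. The sketch below uses the standard library's `importlib.metadata`. The helpers `parse_release` and `meets_minimum` are hypothetical and compare only the leading numeric parts of a version string, so pre-release suffixes such as `rc1` are ignored.

```python
from importlib.metadata import PackageNotFoundError, version


def parse_release(ver: str) -> tuple[int, ...]:
    """Turn '1.18.2' or '1.18rc1' into a tuple of leading integers, e.g. (1, 18, 2)."""
    parts = []
    for piece in ver.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break  # a purely non-numeric component such as 'dev0' ends the release part
        parts.append(int(digits))
        if len(digits) != len(piece):
            break  # stop after a component with a suffix such as '18rc1'
    return tuple(parts)


def meets_minimum(ver: str, minimum: str) -> bool:
    """Hypothetical helper: True if `ver` is at least `minimum` (numeric parts only)."""
    return parse_release(ver) >= parse_release(minimum)


# graphdatascience must be 1.18 or later; for pyspark we only report the version.
for package, minimum in [("graphdatascience", "1.18"), ("pyspark", None)]:
    try:
        installed = version(package)
    except PackageNotFoundError:
        print(f"{package}: not installed")
        continue
    if minimum is None or meets_minimum(installed, minimum):
        print(f"{package} {installed}: ok")
    else:
        print(f"{package} {installed}: older than {minimum}")
```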

1.1. Connecting to a Spark Session

To interact with the Spark cluster, we first need to instantiate a Spark session. This example uses a local Spark session, which runs Spark on the same machine. Connecting to a remote Spark cluster works similarly. For more information about setting up pyspark, visit https://spark.apache.org/docs/latest/api/python/getting_started/

from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[4]").appName("GraphAnalytics").getOrCreate()

# Enable Arrow-based columnar data transfers
spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")

2. Aura API credentials

The entry point for managing GDS Sessions is the GdsSessions object, which requires creating Aura API credentials.

import os

from graphdatascience.session import AuraAPICredentials, GdsSessions

# you can also use AuraAPICredentials.from_env() to load credentials from environment variables
api_credentials = AuraAPICredentials(
    client_id=os.environ["CLIENT_ID"],
    client_secret=os.environ["CLIENT_SECRET"],
    # If your account is a member of several projects, you must also specify the project ID to use
    project_id=os.environ.get("PROJECT_ID", None),
)

sessions = GdsSessions(api_credentials=api_credentials)
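If a variable is missing, `os.environ["CLIENT_ID"]` raises a bare `KeyError`. This is easy to misread because `load_dotenv` silently does nothing when `sessions.env` is not found. The sketch below is a pre-flight check you could run before the cell above. It assumes the same variable names, and the helper `missing_credentials` is hypothetical, not part of graphdatascience.

```python
import os
from collections.abc import Mapping

# Variable names as used in the cell above; PROJECT_ID is optional there.
REQUIRED_VARS = ("CLIENT_ID", "CLIENT_SECRET")


def missing_credentials(env: Mapping[str, str], required=REQUIRED_VARS) -> list[str]:
    """Hypothetical helper: names of required variables that are unset or empty in `env`."""
    return [name for name in required if not env.get(name)]


missing = missing_credentials(os.environ)
if missing:
    print("Missing Aura API credentials:", ", ".join(missing), "(check sessions.env)")
```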

3. Creating a new session

A new session is created by calling sessions.get_or_create() with the following parameters:

  • A session name, which lets you reconnect to an existing session by calling get_or_create again.

  • The session memory.

  • The cloud location.

  • A time-to-live (TTL), which ensures that the session is automatically deleted after being unused for the set time, to avoid incurring costs.

See the API reference documentation or the manual for more details on the parameters.

from datetime import timedelta

from graphdatascience.session import CloudLocation, SessionMemory

# Create a GDS session!
gds = sessions.get_or_create(
    # we give it a representative name
    session_name="bike_trips",
    memory=SessionMemory.m_2GB,
    ttl=timedelta(minutes=30),
    cloud_location=CloudLocation("gcp", "europe-west1"),
)
# Verify connectivity. A failure directly after get_or_create may point to TLS or firewall issues.
gds.verify_connectivity()
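The TTL applies to time the session spends unused, not to time since creation. The worked example below assumes that every interaction restarts the idle clock, which is how we read "unused" here. The service tracks this itself; the helper `expected_expiry` and the timestamps are made up for illustration.

```python
from datetime import datetime, timedelta, timezone


def expected_expiry(activity_times: list[datetime], ttl: timedelta) -> datetime:
    """Earliest deletion time, assuming the idle clock restarts at every interaction."""
    return max(activity_times) + ttl


created = datetime(2024, 5, 1, 12, 0, tzinfo=timezone.utc)  # made-up timestamp
activity = [created, created + timedelta(minutes=20)]  # session last used 20 minutes in
print(expected_expiry(activity, timedelta(minutes=30)))  # 2024-05-01 12:50:00+00:00
```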

4. Adding a dataset

Next, we set up a dataset in Spark. This example uses the New York Bike trips dataset (https://www.kaggle.com/datasets/gabrielramos87/bike-trips). The bike trips form a graph: nodes represent bike rental stations, and each relationship connects the start and end station of one rental trip.

import io
import os
import zipfile

import requests

download_path = "bike_trips_data"
if not os.path.exists(download_path):
    url = "https://www.kaggle.com/api/v1/datasets/download/gabrielramos87/bike-trips"

    response = requests.get(url)
    response.raise_for_status()

    # Unzip the content
    with zipfile.ZipFile(io.BytesIO(response.content)) as z:
        z.extractall(download_path)

df = spark.read.csv(download_path, header=True, inferSchema=True)
df.createOrReplaceTempView("bike_trips")
df.limit(10).show()
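Each row of the trip data is one directed edge from `start_station_id` to `end_station_id`, and many trips share the same station pair. The small standard-library sketch below makes that structure visible. The rows are made up and are not taken from the dataset; only the column names match.

```python
from collections import Counter

# Hypothetical sample rows using the dataset's column names; the values are made up.
sample_rows = [
    {"start_station_id": 1, "end_station_id": 2},
    {"start_station_id": 1, "end_station_id": 2},
    {"start_station_id": 2, "end_station_id": 3},
    {"start_station_id": 3, "end_station_id": 1},
]

# Each trip is a directed edge (source station -> target station).
edge_counts = Counter((row["start_station_id"], row["end_station_id"]) for row in sample_rows)
stations = {station for edge in edge_counts for station in edge}

print(len(stations), "stations,", len(sample_rows), "trips,", len(edge_counts), "distinct station pairs")
print(edge_counts.most_common(1))  # [((1, 2), 2)]
```

On the full data, you can compute the same counts in Spark SQL against the `bike_trips` view with `GROUP BY start_station_id, end_station_id`.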

5. Projecting Graphs

Now that our dataset is available in the Spark session, we can project it to the GDS Session.

We first need to get access to the GDSArrowClient. This client allows us to directly communicate with the Arrow Flight server provided by the session.

Our input data already resembles triplets, where each row represents an edge from a source station to a target station. This allows us to use the Arrow Server’s "graph import from triplets" functionality, which requires the following protocol:

  1. Send the action v2/graph.project.fromTriplets. This initializes the import process and lets us specify the graph name and settings such as undirected_relationship_types. It returns a job ID, which we use to reference the import job in the following steps.

  2. Send the data in batches to the Arrow server.

  3. Send another action called v2/graph.project.fromTriplets.done to tell the import process that no more data will be sent. This will trigger the final graph creation inside the GDS session.

  4. Wait for the import process to reach the DONE state.

The most complex step is running the actual data upload on each Spark worker. We use the mapInArrow function to run custom code on each Spark worker. Each worker receives a number of Arrow record batches, which it sends directly to the GDS session’s Arrow server.



import time

import pandas as pd
import pyarrow
from pyspark.sql import functions

graph_name = "bike_trips"

arrow_client = gds.arrow_client()

# 1. Start the import process
job_id = arrow_client.create_graph_from_triplets(graph_name, concurrency=4)


# Define a function that receives an arrow batch and uploads it to the GDS session
def upload_batch(iterator):
    for batch in iterator:
        arrow_client.upload_triplets(job_id, [batch])
        yield pyarrow.RecordBatch.from_pandas(pd.DataFrame({"batch_rows_imported": [len(batch)]}))


# Select the source target pairs from our source data
source_target_pairs = spark.sql("""
                                SELECT start_station_id AS sourceNode, end_station_id AS targetNode
                                FROM bike_trips
                                """)

# 2. Use the `mapInArrow` function to upload the data to the GDS session. Returns a DataFrame with a single column containing the batch sizes.
uploaded_batches = source_target_pairs.mapInArrow(upload_batch, "batch_rows_imported long")

# Aggregate the batch sizes to receive the row count.
aggregated_batch_sizes = uploaded_batches.agg(functions.sum("batch_rows_imported").alias("rows_imported"))

# Show the result. This will trigger the computation and thus run the data upload.
aggregated_batch_sizes.show()

# 3. Finish the import process
arrow_client.triplet_load_done(job_id)

# 4. Wait for the import to finish
while not arrow_client.job_status(job_id).succeeded():
    time.sleep(1)

G = gds.v2.graph.get(graph_name)
G
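The loop in step 4 polls until the job succeeds and never gives up, so a failed import would block the notebook indefinitely. The variant below adds a deadline. It assumes the same `arrow_client.job_status(job_id).succeeded()` check used above, and `wait_for` is a hypothetical helper, not a library function.

```python
import time
from typing import Callable


def wait_for(condition: Callable[[], bool], timeout_s: float = 600.0, interval_s: float = 1.0) -> None:
    """Poll `condition` every `interval_s` seconds until it is true or `timeout_s` elapses."""
    deadline = time.monotonic() + timeout_s
    while not condition():
        if time.monotonic() >= deadline:
            raise TimeoutError(f"condition not met within {timeout_s} seconds")
        time.sleep(interval_s)


# Usage in the notebook (assumes arrow_client and job_id from the cell above):
# wait_for(lambda: arrow_client.job_status(job_id).succeeded(), timeout_s=600)
```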

6. Running Algorithms

We can run algorithms on the constructed graph using the standard GDS Python Client API. See the other tutorials for more examples.

print("Running PageRank ...")
pr_result = gds.v2.page_rank.mutate(G, mutate_property="pagerank")
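`mutate` stores a `pagerank` score on every node of the in-memory graph `G`. To build intuition for what that score means on a station graph, below is a textbook power-iteration PageRank in plain Python. It is not the GDS implementation: it normalizes scores to sum to 1 and spreads the score of stations without outgoing trips evenly across all stations, so compare rankings rather than absolute values. The damping factor 0.85 is the commonly used value, and the trips are made up.

```python
def pagerank(edges, damping=0.85, iterations=50):
    """Power-iteration PageRank over a list of directed (source, target) edges."""
    nodes = sorted({node for edge in edges for node in edge})
    out_links = {node: [] for node in nodes}
    for src, dst in edges:
        out_links[src].append(dst)  # repeated trips between the same stations add weight

    n = len(nodes)
    scores = {node: 1.0 / n for node in nodes}
    for _ in range(iterations):
        # Score held by nodes without outgoing edges is shared among all nodes.
        dangling = sum(scores[node] for node in nodes if not out_links[node])
        new_scores = {node: (1 - damping) / n + damping * dangling / n for node in nodes}
        for src in nodes:
            targets = out_links[src]
            for dst in targets:
                new_scores[dst] += damping * scores[src] / len(targets)
        scores = new_scores
    return scores


# Made-up trips: station 3 receives trips from every other station.
sample_trips = [(1, 3), (2, 3), (4, 3), (3, 1), (1, 2)]
ranks = pagerank(sample_trips)
print(max(ranks, key=ranks.get))  # 3
```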

7. Sending the computation result back to Spark

Once the computation is done, we may want to use the result further in Spark. As with the projection, we stream batches of data into each of the Spark workers. Retrieving the data is slightly more complicated, because we need an input DataFrame to trigger computations on the Spark workers. As the driving table, we use a range with as many rows as `spark.sparkContext.defaultParallelism`, roughly one row per worker slot in the cluster. On the workers, we disregard the input and instead stream the computation results from the GDS Session.

# 1. Start the node property export on the GDS session
job_id = arrow_client.get_node_properties(G.name(), ["pagerank"])


# Define a function that receives data from the GDS Session and turns it into data batches
def retrieve_data(ignored):
    stream_data = arrow_client.stream_job(G.name(), job_id)
    batches = pyarrow.Table.from_pandas(stream_data).to_batches(1000)
    for b in batches:
        yield b


# Create DataFrame with a single column and one row per worker
input_partitions = spark.range(spark.sparkContext.defaultParallelism).toDF("batch_id")
# 2. Stream the data from the GDS Session into the Spark workers
received_batches = input_partitions.mapInArrow(retrieve_data, "nodeId long, pagerank double")
# Optional: Repartition the data to make sure it is distributed equally
result = received_batches.repartition(numPartitions=spark.sparkContext.defaultParallelism)

result.toPandas()

8. Cleanup

Now that we have finished our analysis, we can delete the GDS session and stop the Spark session.

Deleting the GDS session will release all resources associated with it, and stop incurring costs.

gds.delete()
spark.stop()
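If the analysis fails halfway, execution may never reach `gds.delete()`, and the session keeps running until its TTL expires. In a script, `contextlib.ExitStack` is one way to guarantee that cleanup runs. The sketch below uses hypothetical stand-in objects so that it runs on its own, and the comment shows where `gds.delete` and `spark.stop` would be registered. Callbacks run in reverse registration order, so the GDS session is deleted before Spark stops, as in the cell above.

```python
from contextlib import ExitStack


class FakeResource:
    """Hypothetical stand-in for the GDS or Spark session; records when it is cleaned up."""

    def __init__(self, name, log):
        self.name = name
        self.log = log

    def close(self):
        self.log.append(self.name)


cleanup_log = []
gds_stub = FakeResource("gds.delete", cleanup_log)
spark_stub = FakeResource("spark.stop", cleanup_log)

try:
    with ExitStack() as stack:
        # In the notebook: stack.callback(spark.stop); stack.callback(gds.delete)
        stack.callback(spark_stub.close)
        stack.callback(gds_stub.close)
        raise RuntimeError("simulated failure during the analysis")
except RuntimeError:
    pass  # the error still propagates, but only after both cleanups ran

print(cleanup_log)  # ['gds.delete', 'spark.stop']
```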