Create embeddings with open source libraries

The Python library SentenceTransformers provides pre-trained models to generate embeddings for text and images, and allows you to play with embeddings without needing an account on OpenAI or other proprietary services.

This page assumes you have already imported the recommendations dataset and set up your environment. It shows how to generate and store embeddings for Movie nodes based on their title and plot.

Embeddings are always generated outside of Neo4j, but stored in the Neo4j database.

Setup environment

As a last setup step, install the sentence-transformers package.

pip install sentence-transformers
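To confirm the package is visible to the interpreter you plan to run the example with, a quick check along these lines may help. This is a minimal sketch: the helper name `is_installed` is hypothetical, and it only uses the standard library's `importlib` to look the module up without importing it.

```python
import importlib.util


def is_installed(module_name: str) -> bool:
    """Return True if a top-level module can be found by the current interpreter."""
    # find_spec returns None when no top-level module of that name is found
    return importlib.util.find_spec(module_name) is not None


# `sentence_transformers` is the import name of the sentence-transformers package
if is_installed('sentence_transformers'):
    print('sentence-transformers is available.')
else:
    print('sentence-transformers not found; run: pip install sentence-transformers')
```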

Create embeddings for movies

The example below fetches all Movie nodes from the database, generates an embedding for title and plot, and adds that as an extra embedding property to each node.

from sentence_transformers import SentenceTransformer
import neo4j


URI = '<URI for Neo4j database>'
AUTH = ('<Username>', '<Password>')
DB_NAME = '<Database name>'  # examples: 'recommendations-50', 'neo4j'


def main():
    driver = neo4j.GraphDatabase.driver(URI, auth=AUTH)  # (1)
    driver.verify_connectivity()

    model = SentenceTransformer('all-MiniLM-L6-v2')  # vector size 384  (2)

    batch_size = 100
    batch_n = 1
    movies_with_embeddings = []
    with driver.session(database=DB_NAME) as session:
        # Fetch `Movie` nodes
        result = session.run('MATCH (m:Movie) RETURN m.plot AS plot, m.title AS title')
        for record in result:
            title = record.get('title')
            plot = record.get('plot')

            # Create embedding for title and plot
            if title is not None and plot is not None:
                movies_with_embeddings.append({
                    'title': title,
                    'plot': plot,
                    # (3)
                    'embedding': model.encode(f'''
                        Title: {title}\n
                        Plot: {plot}
                    '''),
                })

            # Import when a batch of movies has embeddings ready; flush buffer
            if len(movies_with_embeddings) == batch_size:  # (4)
                import_batch(driver, movies_with_embeddings, batch_n)
                movies_with_embeddings = []
                batch_n += 1

        # Flush last batch
        import_batch(driver, movies_with_embeddings, batch_n)

    # Import complete, show counters
    records, _, _ = driver.execute_query('''
    MATCH (m:Movie WHERE m.embedding IS NOT NULL)
    RETURN count(*) AS countMoviesWithEmbeddings, size(m.embedding) AS embeddingSize
    ''', database_=DB_NAME)
    print(f"""
Embeddings generated and attached to nodes.
Movie nodes with embeddings: {records[0].get('countMoviesWithEmbeddings')}.
Embedding size: {records[0].get('embeddingSize')}.
    """)


def import_batch(driver, nodes_with_embeddings, batch_n):
    # Add embeddings to Movie nodes  (5)
    driver.execute_query('''
    UNWIND $movies as movie
    MATCH (m:Movie {title: movie.title, plot: movie.plot})
    CALL db.create.setNodeVectorProperty(m, 'embedding', movie.embedding)
    ''', movies=nodes_with_embeddings, database_=DB_NAME)
    print(f'Processed batch {batch_n}.')


if __name__ == '__main__':
    main()

'''
Movie nodes with embeddings: 9083.
Embedding size: 384.
'''
1 The driver object is the interface to interact with your Neo4j instance. For more information, see Build applications with Neo4j and Python.
2 The model all-MiniLM-L6-v2 maps text into vectors of size 384 (i.e. lists of 384 numbers).
3 The .encode() method generates an embedding for the given string (title and plot together, in this case).
4 A number of embeddings are collected before a whole batch is submitted to the database. This avoids holding the whole dataset in memory and reduces the risk of timeouts (especially relevant for larger datasets).
5 The import query sets a new embedding property on each node m, with the embedding vector movie.embedding as value. It uses the Cypher procedure db.create.setNodeVectorProperty, which stores vector properties more efficiently than if they were added with the SET Cypher clause. To set vector properties on relationships, use db.create.setRelationshipVectorProperty.
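The collect-and-flush pattern described in note 4 can also be written as a small generator, which keeps the batching logic separate from the database code. The following is a minimal sketch with no Neo4j or model dependency: the helper name `batched` and the stand-in movie data are assumptions for illustration, not part of the example above.

```python
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar('T')


def batched(items: Iterable[T], batch_size: int) -> Iterator[List[T]]:
    """Yield lists of up to `batch_size` items, including a final partial batch."""
    if batch_size < 1:
        raise ValueError('batch_size must be at least 1')
    batch: List[T] = []
    for item in items:
        batch.append(item)
        if len(batch) == batch_size:
            yield batch
            batch = []  # start a fresh buffer after each flush
    if batch:  # flush leftovers; nothing is yielded if the buffer is empty
        yield batch


# Stand-in data: seven fake movies, grouped three at a time
movies = [{'title': f'Movie {i}'} for i in range(7)]
batch_sizes = [len(batch) for batch in batched(movies, 3)]
print(batch_sizes)  # [3, 3, 1]
```

In a script like the one above, each yielded batch would be handed to `import_batch`. Unlike the unconditional final flush in the example, this sketch skips the last call when no items are left over.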

Once embeddings are in the database, you can use them to compare how similar one movie is to another.
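As a rough illustration of what comparing two embeddings means, the sketch below computes cosine similarity in plain Python. Cosine similarity is one common measure; this page does not state which one is used to compare movies, so treat that choice as an assumption. The function name `cosine_similarity` and the toy 3-dimensional vectors are hypothetical stand-ins for real 384-dimensional embeddings.

```python
import math
from typing import Sequence


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Return the cosine of the angle between two equal-length vectors."""
    if len(a) != len(b):
        raise ValueError('vectors must have the same length')
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0 or norm_b == 0:
        raise ValueError('cosine similarity is undefined for zero vectors')
    return dot / (norm_a * norm_b)


# Vectors pointing the same way score close to 1; orthogonal vectors score 0
same_direction = cosine_similarity([1.0, 2.0, 3.0], [2.0, 4.0, 6.0])
orthogonal = cosine_similarity([1.0, 0.0, 0.0], [0.0, 1.0, 0.0])
print(round(same_direction, 6), orthogonal)
```

A score near 1 suggests the two texts are semantically close, while a score near 0 suggests they are unrelated.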