Source code for neo4j_graphrag.experimental.components.entity_relation_extractor

#  Copyright (c) "Neo4j"
#  Neo4j Sweden AB [https://neo4j.com]
#  #
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#  #
#      https://www.apache.org/licenses/LICENSE-2.0
#  #
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
from __future__ import annotations

import asyncio
import enum
import json
import logging
from typing import Any, List, Optional, Union, cast, Dict

import json_repair
from pydantic import ValidationError, validate_call

from neo4j_graphrag.exceptions import LLMGenerationError
from neo4j_graphrag.experimental.components.lexical_graph import LexicalGraphBuilder
from neo4j_graphrag.experimental.components.schema import SchemaConfig
from neo4j_graphrag.experimental.components.types import (
    DocumentInfo,
    LexicalGraphConfig,
    Neo4jGraph,
    Neo4jNode,
    Neo4jRelationship,
    TextChunk,
    TextChunks,
    SchemaEnforcementMode,
)
from neo4j_graphrag.experimental.pipeline.component import Component
from neo4j_graphrag.experimental.pipeline.exceptions import InvalidJSONError
from neo4j_graphrag.generation.prompts import ERExtractionTemplate, PromptTemplate
from neo4j_graphrag.llm import LLMInterface
from neo4j_graphrag.utils.logging import prettify

logger = logging.getLogger(__name__)


class OnError(enum.Enum):
    RAISE = "RAISE"
    IGNORE = "IGNORE"

    @classmethod
    def possible_values(cls) -> List[str]:
        return [e.value for e in cls]


def balance_curly_braces(json_string: str) -> str:
    """
    Balances curly braces `{}` in a JSON string. This function ensures that every opening brace has a corresponding
    closing brace, but only when they are not part of a string value. If there are unbalanced closing braces,
    they are ignored. If there are missing closing braces, they are appended at the end of the string.

    Args:
        json_string (str): A potentially malformed JSON string with unbalanced curly braces.

    Returns:
        str: A JSON string with balanced curly braces.
    """
    stack = []
    fixed_json = []
    in_string = False
    escape = False

    for char in json_string:
        if char == '"' and not escape:
            in_string = not in_string
        elif char == "\\" and in_string:
            escape = not escape
            fixed_json.append(char)
            continue
        else:
            escape = False

        if not in_string:
            if char == "{":
                stack.append(char)
                fixed_json.append(char)
            elif char == "}" and stack and stack[-1] == "{":
                stack.pop()
                fixed_json.append(char)
            elif char == "}" and (not stack or stack[-1] != "{"):
                continue
            else:
                fixed_json.append(char)
        else:
            fixed_json.append(char)

    # If stack is not empty, add missing closing braces
    while stack:
        stack.pop()
        fixed_json.append("}")

    return "".join(fixed_json)


def fix_invalid_json(raw_json: str) -> str:
    repaired_json = json_repair.repair_json(raw_json)
    repaired_json = cast(str, repaired_json).strip()

    if repaired_json == '""':
        raise InvalidJSONError("JSON repair resulted in an empty or invalid JSON.")
    if not repaired_json:
        raise InvalidJSONError("JSON repair resulted in an empty string.")
    return repaired_json
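
# Illustrative sketch (not part of the library source): fix_invalid_json
# delegates to json_repair, so common LLM mistakes such as single quotes and
# trailing commas are repaired before json.loads is called. The payload below
# is an assumption for the demo.
def _demo_fix_invalid_json() -> None:
    raw = "{'nodes': [{'id': '0', 'label': 'Person', 'properties': {}}],}"
    repaired = fix_invalid_json(raw)
    parsed = json.loads(repaired)  # json.loads(raw) would raise JSONDecodeError
    assert parsed["nodes"][0]["label"] == "Person"
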


class EntityRelationExtractor(Component):
    """Abstract class for entity relation extraction components.

    Args:
        on_error (OnError): What to do when an error occurs during extraction.
            Defaults to ignoring the error (OnError.IGNORE).
        create_lexical_graph (bool): Whether to include the text chunks in the graph
            in addition to the extracted entities and relations. Defaults to True.
    """

    def __init__(
        self,
        *args: Any,
        on_error: OnError = OnError.IGNORE,
        create_lexical_graph: bool = True,
        **kwargs: Any,
    ) -> None:
        self.on_error = on_error
        self.create_lexical_graph = create_lexical_graph

    async def run(
        self,
        chunks: TextChunks,
        document_info: Optional[DocumentInfo] = None,
        lexical_graph_config: Optional[LexicalGraphConfig] = None,
        **kwargs: Any,
    ) -> Neo4jGraph:
        raise NotImplementedError()

    def update_ids(
        self,
        graph: Neo4jGraph,
        chunk: TextChunk,
    ) -> Neo4jGraph:
        """Make node IDs unique across chunks, documents, and pipeline runs
        by prefixing them with a unique prefix.
        """
        prefix = f"{chunk.chunk_id}"
        for node in graph.nodes:
            node.id = f"{prefix}:{node.id}"
            if node.properties is None:
                node.properties = {}
            node.properties.update({"chunk_index": chunk.index})
        for rel in graph.relationships:
            rel.start_node_id = f"{prefix}:{rel.start_node_id}"
            rel.end_node_id = f"{prefix}:{rel.end_node_id}"
        return graph
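
# Illustrative sketch (not part of the library source): update_ids namespaces
# node IDs by chunk so that graphs extracted from different chunks can be
# merged without ID collisions. The TextChunk constructor arguments are
# assumptions for the demo; only the fields used above (chunk_id, index)
# matter here.
def _demo_update_ids() -> None:
    extractor = EntityRelationExtractor()
    graph = Neo4jGraph(
        nodes=[Neo4jNode(id="0", label="Person", properties={"name": "Alice"})],
        relationships=[],
    )
    chunk = TextChunk(text="Alice.", index=0)
    updated = extractor.update_ids(graph, chunk)
    # "0" has become "<chunk_id>:0", and the node remembers its source chunk:
    assert updated.nodes[0].id == f"{chunk.chunk_id}:0"
    assert updated.nodes[0].properties["chunk_index"] == 0
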
class LLMEntityRelationExtractor(EntityRelationExtractor):
    """
    Extracts a knowledge graph from a series of text chunks using a large language model.

    Args:
        llm (LLMInterface): The language model to use for extraction.
        prompt_template (ERExtractionTemplate | str): A custom prompt template to use for extraction.
        create_lexical_graph (bool): Whether to include the text chunks in the graph
            in addition to the extracted entities and relations. Defaults to True.
        enforce_schema (SchemaEnforcementMode): Whether to validate the extracted
            entities/relations against the provided schema. Defaults to SchemaEnforcementMode.NONE.
        on_error (OnError): What to do when an error occurs during extraction.
            Defaults to raising an error (OnError.RAISE).
        max_concurrency (int): The maximum number of concurrent tasks which can be
            used to make requests to the LLM.

    Example:

    .. code-block:: python

        from neo4j_graphrag.experimental.components.entity_relation_extractor import LLMEntityRelationExtractor
        from neo4j_graphrag.llm import OpenAILLM
        from neo4j_graphrag.experimental.pipeline import Pipeline

        llm = OpenAILLM(
            model_name="gpt-4o",
            model_params={"temperature": 0, "response_format": {"type": "object"}},
        )

        extractor = LLMEntityRelationExtractor(llm=llm)
        pipe = Pipeline()
        pipe.add_component(extractor, "extractor")
    """

    def __init__(
        self,
        llm: LLMInterface,
        prompt_template: ERExtractionTemplate | str = ERExtractionTemplate(),
        create_lexical_graph: bool = True,
        enforce_schema: SchemaEnforcementMode = SchemaEnforcementMode.NONE,
        on_error: OnError = OnError.RAISE,
        max_concurrency: int = 5,
    ) -> None:
        super().__init__(on_error=on_error, create_lexical_graph=create_lexical_graph)
        self.llm = llm  # with response_format={"type": "json_object"}
        self.enforce_schema = enforce_schema
        self.max_concurrency = max_concurrency
        if isinstance(prompt_template, str):
            template = PromptTemplate(prompt_template, expected_inputs=[])
        else:
            template = prompt_template
        self.prompt_template = template

    async def extract_for_chunk(
        self, schema: SchemaConfig, examples: str, chunk: TextChunk
    ) -> Neo4jGraph:
        """Run entity extraction for a given text chunk."""
        prompt = self.prompt_template.format(
            text=chunk.text, schema=schema.model_dump(), examples=examples
        )
        llm_result = await self.llm.ainvoke(prompt)
        try:
            llm_generated_json = fix_invalid_json(llm_result.content)
            result = json.loads(llm_generated_json)
        except (json.JSONDecodeError, InvalidJSONError) as e:
            if self.on_error == OnError.RAISE:
                raise LLMGenerationError("LLM response is not valid JSON") from e
            else:
                logger.error(
                    f"LLM response is not valid JSON for chunk_index={chunk.index}"
                )
                logger.debug(f"Invalid JSON: {llm_result.content}")
            result = {"nodes": [], "relationships": []}
        try:
            chunk_graph = Neo4jGraph.model_validate(result)
        except ValidationError as e:
            if self.on_error == OnError.RAISE:
                raise LLMGenerationError("LLM response has improper format") from e
            else:
                logger.error(
                    f"LLM response has improper format for chunk_index={chunk.index}"
                )
                logger.debug(f"Invalid JSON format: {result}")
            chunk_graph = Neo4jGraph()
        return chunk_graph

    async def post_process_chunk(
        self,
        chunk_graph: Neo4jGraph,
        chunk: TextChunk,
        lexical_graph_builder: Optional[LexicalGraphBuilder] = None,
    ) -> None:
        """Perform post-processing after entity and relation extraction:

        - Update node IDs to make them unique across chunks
        - Build the lexical graph if requested
        """
        self.update_ids(chunk_graph, chunk)
        if lexical_graph_builder:
            await lexical_graph_builder.process_chunk_extracted_entities(
                chunk_graph,
                chunk,
            )

    def combine_chunk_graphs(
        self, lexical_graph: Optional[Neo4jGraph], chunk_graphs: List[Neo4jGraph]
    ) -> Neo4jGraph:
        """Combine sub-graphs obtained for each chunk into a single Neo4jGraph object"""
        if lexical_graph:
            graph = lexical_graph.model_copy(deep=True)
        else:
            graph = Neo4jGraph()
        for chunk_graph in chunk_graphs:
            graph.nodes.extend(chunk_graph.nodes)
            graph.relationships.extend(chunk_graph.relationships)
        return graph

    async def run_for_chunk(
        self,
        sem: asyncio.Semaphore,
        chunk: TextChunk,
        schema: SchemaConfig,
        examples: str,
        lexical_graph_builder: Optional[LexicalGraphBuilder] = None,
    ) -> Neo4jGraph:
        """Run extraction, validation and post-processing for a single chunk"""
        async with sem:
            chunk_graph = await self.extract_for_chunk(schema, examples, chunk)
            final_chunk_graph = self.validate_chunk(chunk_graph, schema)
            await self.post_process_chunk(
                final_chunk_graph,
                chunk,
                lexical_graph_builder,
            )
            return final_chunk_graph

    @validate_call
    async def run(
        self,
        chunks: TextChunks,
        document_info: Optional[DocumentInfo] = None,
        lexical_graph_config: Optional[LexicalGraphConfig] = None,
        schema: Union[SchemaConfig, None] = None,
        examples: str = "",
        **kwargs: Any,
    ) -> Neo4jGraph:
        """Perform entity and relation extraction for all chunks in a list.

        Optionally, creates the "lexical graph" by adding nodes and relationships
        to represent the document and its chunks in the returned graph
        (for more details, see the :ref:`Lexical Graph Builder doc <lexical-graph-builder>`
        and the :ref:`User Guide <lexical-graph-in-er-extraction>`).

        Args:
            chunks (TextChunks): List of text chunks to extract entities and relations from.
            document_info (Optional[DocumentInfo], optional): Document the chunks are
                coming from. Used in the lexical graph creation step.
            lexical_graph_config (Optional[LexicalGraphConfig], optional): Lexical graph
                configuration to customize node labels and relationship types in the
                lexical graph.
            schema (SchemaConfig | None): Definition of the schema to guide the LLM in
                its extraction.
            examples (str): Examples for few-shot learning in the prompt.
        """
        lexical_graph_builder = None
        lexical_graph = None
        if self.create_lexical_graph:
            config = lexical_graph_config or LexicalGraphConfig()
            lexical_graph_builder = LexicalGraphBuilder(config=config)
            lexical_graph_result = await lexical_graph_builder.run(
                text_chunks=chunks, document_info=document_info
            )
            lexical_graph = lexical_graph_result.graph
        elif lexical_graph_config:
            lexical_graph_builder = LexicalGraphBuilder(config=lexical_graph_config)
        schema = schema or SchemaConfig(entities={}, relations={}, potential_schema=[])
        examples = examples or ""
        sem = asyncio.Semaphore(self.max_concurrency)
        tasks = [
            self.run_for_chunk(
                sem,
                chunk,
                schema,
                examples,
                lexical_graph_builder,
            )
            for chunk in chunks.chunks
        ]
        chunk_graphs: list[Neo4jGraph] = list(await asyncio.gather(*tasks))
        graph = self.combine_chunk_graphs(lexical_graph, chunk_graphs)
        logger.debug(f"Extracted graph: {prettify(graph)}")
        return graph

    def validate_chunk(
        self, chunk_graph: Neo4jGraph, schema: SchemaConfig
    ) -> Neo4jGraph:
        """
        Perform validation after entity and relation extraction:

        - Enforce schema if schema enforcement mode is on and a schema is provided
        """
        if self.enforce_schema != SchemaEnforcementMode.NONE:
            if not schema or not schema.entities:  # schema is not provided
                logger.warning(
                    "Schema enforcement is ON but the guiding schema is not provided."
                )
            else:
                # if schema enforcement is on and the schema is provided, clean the graph
                return self._clean_graph(chunk_graph, schema)
        return chunk_graph

    def _clean_graph(
        self,
        graph: Neo4jGraph,
        schema: SchemaConfig,
    ) -> Neo4jGraph:
        """
        Verify that the graph conforms to the provided schema.

        Remove invalid entities, relationships, and properties.
        If an entity is removed, all of its relationships are also removed.
        If no valid properties remain for an entity, remove that entity.
        """
        # enforce nodes (remove invalid labels, strip invalid properties)
        filtered_nodes = self._enforce_nodes(graph.nodes, schema)

        # enforce relationships (remove those referencing invalid nodes or with
        # invalid types or with start/end nodes not conforming to the schema,
        # and strip invalid properties)
        filtered_rels = self._enforce_relationships(
            graph.relationships, filtered_nodes, schema
        )

        return Neo4jGraph(nodes=filtered_nodes, relationships=filtered_rels)

    def _enforce_nodes(
        self, extracted_nodes: List[Neo4jNode], schema: SchemaConfig
    ) -> List[Neo4jNode]:
        """
        Filter extracted nodes to be conformant to the schema.

        Keep only those whose label is in the schema.
        For each valid node, filter out properties not present in the schema.
        Remove a node if it ends up with no valid properties.
        """
        if self.enforce_schema != SchemaEnforcementMode.STRICT:
            return extracted_nodes

        valid_nodes = []

        for node in extracted_nodes:
            schema_entity = schema.entities.get(node.label)
            if not schema_entity:
                continue
            allowed_props = schema_entity.get("properties", [])
            filtered_props = self._enforce_properties(node.properties, allowed_props)
            if filtered_props:
                valid_nodes.append(
                    Neo4jNode(
                        id=node.id,
                        label=node.label,
                        properties=filtered_props,
                        embedding_properties=node.embedding_properties,
                    )
                )

        return valid_nodes

    def _enforce_relationships(
        self,
        extracted_relationships: List[Neo4jRelationship],
        filtered_nodes: List[Neo4jNode],
        schema: SchemaConfig,
    ) -> List[Neo4jRelationship]:
        """
        Filter extracted relationships to be conformant to the schema.

        Keep only those whose type is in the schema, whose start/end nodes conform
        to the schema, and whose start/end nodes are in the filtered nodes (i.e.,
        kept after node enforcement). For each valid relationship, filter out
        properties not present in the schema.
        If a relationship's direction is incorrect, invert it.
        """
        if self.enforce_schema != SchemaEnforcementMode.STRICT:
            return extracted_relationships

        valid_rels = []

        valid_nodes = {node.id: node.label for node in filtered_nodes}

        potential_schema = schema.potential_schema

        for rel in extracted_relationships:
            schema_relation = (
                schema.relations.get(rel.type) if schema.relations else None
            )
            if not schema_relation:
                continue

            if (
                rel.start_node_id not in valid_nodes
                or rel.end_node_id not in valid_nodes
            ):
                continue

            start_label = valid_nodes[rel.start_node_id]
            end_label = valid_nodes[rel.end_node_id]

            tuple_valid = True
            if potential_schema:
                tuple_valid = (start_label, rel.type, end_label) in potential_schema
                reverse_tuple_valid = (
                    end_label,
                    rel.type,
                    start_label,
                ) in potential_schema

                if not tuple_valid and not reverse_tuple_valid:
                    continue

            allowed_props = schema_relation.get("properties", [])
            filtered_props = self._enforce_properties(rel.properties, allowed_props)

            valid_rels.append(
                Neo4jRelationship(
                    start_node_id=rel.start_node_id if tuple_valid else rel.end_node_id,
                    end_node_id=rel.end_node_id if tuple_valid else rel.start_node_id,
                    type=rel.type,
                    properties=filtered_props,
                    embedding_properties=rel.embedding_properties,
                )
            )

        return valid_rels

    def _enforce_properties(
        self, properties: Dict[str, Any], valid_properties: List[Dict[str, Any]]
    ) -> Dict[str, Any]:
        """
        Filter properties.

        Keep only those that exist in the schema (i.e., valid properties).
        """
        valid_prop_names = {prop["name"] for prop in valid_properties}
        return {
            key: value for key, value in properties.items() if key in valid_prop_names
        }