Source code for neo4j_graphrag.experimental.pipeline.config.template_pipeline.simple_kg_builder

#  Copyright (c) "Neo4j"
#  Neo4j Sweden AB [https://neo4j.com]
#  #
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#  #
#      https://www.apache.org/licenses/LICENSE-2.0
#  #
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
from typing import Any, ClassVar, Literal, Optional, Sequence, Union

from pydantic import ConfigDict

from neo4j_graphrag.experimental.components.embedder import TextChunkEmbedder
from neo4j_graphrag.experimental.components.entity_relation_extractor import (
    EntityRelationExtractor,
    LLMEntityRelationExtractor,
    OnError,
)
from neo4j_graphrag.experimental.components.kg_writer import KGWriter, Neo4jWriter
from neo4j_graphrag.experimental.components.pdf_loader import PdfLoader
from neo4j_graphrag.experimental.components.resolver import (
    EntityResolver,
    SinglePropertyExactMatchResolver,
)
from neo4j_graphrag.experimental.components.schema import (
    SchemaBuilder,
    SchemaEntity,
    SchemaRelation,
)
from neo4j_graphrag.experimental.components.text_splitters.base import TextSplitter
from neo4j_graphrag.experimental.components.text_splitters.fixed_size_splitter import (
    FixedSizeSplitter,
)
from neo4j_graphrag.experimental.components.types import (
    LexicalGraphConfig,
    SchemaEnforcementMode,
)
from neo4j_graphrag.experimental.pipeline.config.object_config import ComponentType
from neo4j_graphrag.experimental.pipeline.config.template_pipeline.base import (
    TemplatePipelineConfig,
)
from neo4j_graphrag.experimental.pipeline.config.types import PipelineType
from neo4j_graphrag.experimental.pipeline.exceptions import PipelineDefinitionError
from neo4j_graphrag.experimental.pipeline.types import (
    ConnectionDefinition,
    EntityInputType,
    RelationInputType,
)
from neo4j_graphrag.generation.prompts import ERExtractionTemplate


class SimpleKGPipelineConfig(TemplatePipelineConfig):
    COMPONENTS: ClassVar[list[str]] = [
        "pdf_loader",
        "splitter",
        "chunk_embedder",
        "schema",
        "extractor",
        "writer",
        "resolver",
    ]

    template_: Literal[PipelineType.SIMPLE_KG_PIPELINE] = (
        PipelineType.SIMPLE_KG_PIPELINE
    )

    from_pdf: bool = False
    entities: Sequence[EntityInputType] = []
    relations: Sequence[RelationInputType] = []
    potential_schema: Optional[list[tuple[str, str, str]]] = None
    enforce_schema: SchemaEnforcementMode = SchemaEnforcementMode.NONE
    on_error: OnError = OnError.IGNORE
    prompt_template: Union[ERExtractionTemplate, str] = ERExtractionTemplate()
    perform_entity_resolution: bool = True
    lexical_graph_config: Optional[LexicalGraphConfig] = None
    neo4j_database: Optional[str] = None

    pdf_loader: Optional[ComponentType] = None
    kg_writer: Optional[ComponentType] = None
    text_splitter: Optional[ComponentType] = None

    model_config = ConfigDict(arbitrary_types_allowed=True)

    def _get_pdf_loader(self) -> Optional[PdfLoader]:
        if not self.from_pdf:
            return None
        if self.pdf_loader:
            return self.pdf_loader.parse(self._global_data)  # type: ignore
        return PdfLoader()

    def _get_run_params_for_pdf_loader(self) -> dict[str, Any]:
        if not self.from_pdf:
            return {}
        if self.pdf_loader:
            return self.pdf_loader.get_run_params(self._global_data)
        return {}

    def _get_splitter(self) -> TextSplitter:
        if self.text_splitter:
            return self.text_splitter.parse(self._global_data)  # type: ignore
        return FixedSizeSplitter()

    def _get_run_params_for_splitter(self) -> dict[str, Any]:
        if self.text_splitter:
            return self.text_splitter.get_run_params(self._global_data)
        return {}

    def _get_chunk_embedder(self) -> TextChunkEmbedder:
        return TextChunkEmbedder(embedder=self.get_default_embedder())

    def _get_schema(self) -> SchemaBuilder:
        return SchemaBuilder()

    def _get_run_params_for_schema(self) -> dict[str, Any]:
        return {
            "entities": [SchemaEntity.from_text_or_dict(e) for e in self.entities],
            "relations": [SchemaRelation.from_text_or_dict(r) for r in self.relations],
            "potential_schema": self.potential_schema,
        }

    def _get_extractor(self) -> EntityRelationExtractor:
        return LLMEntityRelationExtractor(
            llm=self.get_default_llm(),
            prompt_template=self.prompt_template,
            enforce_schema=self.enforce_schema,
            on_error=self.on_error,
        )

    def _get_writer(self) -> KGWriter:
        if self.kg_writer:
            return self.kg_writer.parse(self._global_data)  # type: ignore
        return Neo4jWriter(
            driver=self.get_default_neo4j_driver(),
            neo4j_database=self.neo4j_database,
        )

    def _get_run_params_for_writer(self) -> dict[str, Any]:
        if self.kg_writer:
            return self.kg_writer.get_run_params(self._global_data)
        return {}

    def _get_resolver(self) -> Optional[EntityResolver]:
        if not self.perform_entity_resolution:
            return None
        return SinglePropertyExactMatchResolver(
            driver=self.get_default_neo4j_driver(),
            neo4j_database=self.neo4j_database,
        )

    def _get_connections(self) -> list[ConnectionDefinition]:
        connections = []
        if self.from_pdf:
            connections.append(
                ConnectionDefinition(
                    start="pdf_loader",
                    end="splitter",
                    input_config={"text": "pdf_loader.text"},
                )
            )
            connections.append(
                ConnectionDefinition(
                    start="schema",
                    end="extractor",
                    input_config={
                        "schema": "schema",
                        "document_info": "pdf_loader.document_info",
                    },
                )
            )
        else:
            connections.append(
                ConnectionDefinition(
                    start="schema",
                    end="extractor",
                    input_config={
                        "schema": "schema",
                    },
                )
            )
        connections.append(
            ConnectionDefinition(
                start="splitter",
                end="chunk_embedder",
                input_config={
                    "text_chunks": "splitter",
                },
            )
        )
        connections.append(
            ConnectionDefinition(
                start="chunk_embedder",
                end="extractor",
                input_config={
                    "chunks": "chunk_embedder",
                },
            )
        )
        connections.append(
            ConnectionDefinition(
                start="extractor",
                end="writer",
                input_config={
                    "graph": "extractor",
                },
            )
        )
        if self.perform_entity_resolution:
            connections.append(
                ConnectionDefinition(
                    start="writer",
                    end="resolver",
                    input_config={},
                )
            )
        return connections

    def get_run_params(self, user_input: dict[str, Any]) -> dict[str, Any]:
        run_params = {}
        if self.lexical_graph_config:
            run_params["extractor"] = {
                "lexical_graph_config": self.lexical_graph_config
            }
        text = user_input.get("text")
        file_path = user_input.get("file_path")
        if not ((text is None) ^ (file_path is None)):
            # exactly one of 'text' or 'file_path' must be set
            raise PipelineDefinitionError(
                "Use either 'text' (when from_pdf=False) or 'file_path' (when from_pdf=True) argument."
            )
        if self.from_pdf:
            if not file_path:
                raise PipelineDefinitionError(
                    "Expected 'file_path' argument when 'from_pdf' is True."
                )
            run_params["pdf_loader"] = {"filepath": file_path}
        else:
            if not text:
                raise PipelineDefinitionError(
                    "Expected 'text' argument when 'from_pdf' is False."
                )
            run_params["splitter"] = {"text": text}
        return run_params
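# ------------------------------------------------------------------------------
# Illustrative usage sketch (editor's addition, not part of the library source).
# It only uses fields and methods defined in this module and assumes the
# default LLM, embedder, and Neo4j driver are supplied elsewhere via the parent
# TemplatePipelineConfig (get_default_llm / get_default_embedder /
# get_default_neo4j_driver); those defaults are not configured here.
if __name__ == "__main__":
    # Configure a text-based (non-PDF) knowledge-graph build with a small
    # guiding schema and entity resolution enabled.
    config = SimpleKGPipelineConfig(
        from_pdf=False,
        entities=["Person", "Organization"],
        relations=["WORKS_FOR"],
        potential_schema=[("Person", "WORKS_FOR", "Organization")],
        perform_entity_resolution=True,
    )

    # get_run_params validates that exactly one of 'text' or 'file_path' is
    # provided and routes it to the appropriate component: here the raw text
    # goes to the splitter, i.e. {"splitter": {"text": "..."}}.
    params = config.get_run_params({"text": "Alice works for Acme Corp."})
    print(params)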