[Issue]: cannot import name 'store_entity_semantic_embeddings'
Do you need to file an issue?
- [ ] I have searched the existing issues and this bug is not already filed.
- [ ] My model is hosted on OpenAI or Azure. If not, please look at the "model providers" issue and don't file a new one here.
- [ ] I believe this is a legitimate bug, not just a question. If this is a question, please use the Discussions area.
Describe the issue
I copied the code from https://github.com/win4r/GraphRAG4OpenWebUI to my local machine. That project targets GraphRAG version 0.3.3, but I have the latest version, 1.2.0, installed. The code throws this error: `cannot import name 'store_entity_semantic_embeddings' from 'graphrag.query.input.loaders.dfs'`. How can I resolve this? Here is the code:

````python
import os
import asyncio
import time
import uuid
import json
import re
import pandas as pd
import tiktoken
import logging
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel, Field
from typing import List, Optional, Dict, Any, Union
from contextlib import asynccontextmanager
from tavily import TavilyClient
# GraphRAG-related imports
from graphrag.query.context_builder.entity_extraction import EntityVectorStoreKey
from graphrag.query.indexer_adapters import (
    read_indexer_covariates,
    read_indexer_entities,
    read_indexer_relationships,
    read_indexer_reports,
    read_indexer_text_units,
)
from graphrag.query.input.loaders.dfs import store_entity_semantic_embeddings
from graphrag.query.llm.oai.chat_openai import ChatOpenAI
from graphrag.query.llm.oai.embedding import OpenAIEmbedding
from graphrag.query.llm.oai.typing import OpenaiApiType
from graphrag.query.question_gen.local_gen import LocalQuestionGen
from graphrag.query.structured_search.local_search.mixed_context import LocalSearchMixedContext
from graphrag.query.structured_search.local_search.search import LocalSearch
from graphrag.query.structured_search.global_search.community_context import GlobalCommunityContext
from graphrag.query.structured_search.global_search.search import GlobalSearch
from graphrag.vector_stores.lancedb import LanceDBVectorStore
# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# Constants and configuration
INPUT_DIR = os.getenv('INPUT_DIR')
LANCEDB_URI = f"{INPUT_DIR}/lancedb"
COMMUNITY_REPORT_TABLE = "create_final_community_reports"
ENTITY_TABLE = "create_final_nodes"
ENTITY_EMBEDDING_TABLE = "create_final_entities"
RELATIONSHIP_TABLE = "create_final_relationships"
COVARIATE_TABLE = "create_final_covariates"
TEXT_UNIT_TABLE = "create_final_text_units"
COMMUNITY_LEVEL = 2
PORT = 8012
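# Unverified assumption: newer GraphRAG releases renamed these index outputs
# (e.g. "entities.parquet" instead of "create_final_entities.parquet"), so
# verify which parquet files your indexing run actually produced.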
# Global variables holding the search engines and the question generator
local_search_engine = None
global_search_engine = None
question_generator = None
# Data models
class Message(BaseModel):
    role: str
    content: str

class ChatCompletionRequest(BaseModel):
    model: str
    messages: List[Message]
    temperature: Optional[float] = 1.0
    top_p: Optional[float] = 1.0
    n: Optional[int] = 1
    stream: Optional[bool] = False
    stop: Optional[Union[str, List[str]]] = None
    max_tokens: Optional[int] = None
    presence_penalty: Optional[float] = 0
    frequency_penalty: Optional[float] = 0
    logit_bias: Optional[Dict[str, float]] = None
    user: Optional[str] = None

class ChatCompletionResponseChoice(BaseModel):
    index: int
    message: Message
    finish_reason: Optional[str] = None

class Usage(BaseModel):
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int

class ChatCompletionResponse(BaseModel):
    id: str = Field(default_factory=lambda: f"chatcmpl-{uuid.uuid4().hex}")
    object: str = "chat.completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[ChatCompletionResponseChoice]
    usage: Usage
    system_fingerprint: Optional[str] = None
async def setup_llm_and_embedder():
    """Set up the language model (LLM) and the embedding model."""
    logger.info("Setting up LLM and embedder")

    # Get API keys and base URLs
    api_key = os.environ.get("GRAPHRAG_API_KEY", "YOUR_API_KEY")
    api_key_embedding = os.environ.get("GRAPHRAG_API_KEY_EMBEDDING", api_key)
    api_base = os.environ.get("API_BASE", "https://api.openai.com/v1")
    api_base_embedding = os.environ.get("API_BASE_EMBEDDING", "https://api.openai.com/v1")

    # Get model names
    llm_model = os.environ.get("GRAPHRAG_LLM_MODEL", "gpt-3.5-turbo-0125")
    embedding_model = os.environ.get("GRAPHRAG_EMBEDDING_MODEL", "text-embedding-3-small")

    # Check that an API key is present
    if api_key == "YOUR_API_KEY":
        logger.error("No valid GRAPHRAG_API_KEY found in the environment")
        raise ValueError("GRAPHRAG_API_KEY is not set correctly")

    # Initialize the ChatOpenAI instance
    llm = ChatOpenAI(
        api_key=api_key,
        api_base=api_base,
        model=llm_model,
        api_type=OpenaiApiType.OpenAI,
        max_retries=20,
    )

    # Initialize the token encoder
    token_encoder = tiktoken.get_encoding("cl100k_base")

    # Initialize the text embedding model
    text_embedder = OpenAIEmbedding(
        api_key=api_key_embedding,
        api_base=api_base_embedding,
        api_type=OpenaiApiType.OpenAI,
        model=embedding_model,
        deployment_name=embedding_model,
        max_retries=20,
    )

    logger.info("LLM and embedder set up")
    return llm, token_encoder, text_embedder
async def load_context():
    """Load context data: entities, relationships, reports, text units, and covariates."""
    logger.info("Loading context data")
    try:
        entity_df = pd.read_parquet(f"{INPUT_DIR}/{ENTITY_TABLE}.parquet")
        entity_embedding_df = pd.read_parquet(f"{INPUT_DIR}/{ENTITY_EMBEDDING_TABLE}.parquet")
        entities = read_indexer_entities(entity_df, entity_embedding_df, COMMUNITY_LEVEL)

        description_embedding_store = LanceDBVectorStore(collection_name="entity_description_embeddings")
        description_embedding_store.connect(db_uri=LANCEDB_URI)
        store_entity_semantic_embeddings(entities=entities, vectorstore=description_embedding_store)
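        # This is the only place store_entity_semantic_embeddings is used;
        # a hedged version-compatibility sketch is included after the script.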

        relationship_df = pd.read_parquet(f"{INPUT_DIR}/{RELATIONSHIP_TABLE}.parquet")
        relationships = read_indexer_relationships(relationship_df)

        report_df = pd.read_parquet(f"{INPUT_DIR}/{COMMUNITY_REPORT_TABLE}.parquet")
        reports = read_indexer_reports(report_df, entity_df, COMMUNITY_LEVEL)

        text_unit_df = pd.read_parquet(f"{INPUT_DIR}/{TEXT_UNIT_TABLE}.parquet")
        text_units = read_indexer_text_units(text_unit_df)

        covariate_df = pd.read_parquet(f"{INPUT_DIR}/{COVARIATE_TABLE}.parquet")
        claims = read_indexer_covariates(covariate_df)
        logger.info(f"Number of claim records: {len(claims)}")
        covariates = {"claims": claims}

        logger.info("Context data loaded")
        return entities, relationships, reports, text_units, description_embedding_store, covariates
    except Exception as e:
        logger.error(f"Error loading context data: {str(e)}")
        raise
async def setup_search_engines(llm, token_encoder, text_embedder, entities, relationships,
                               reports, text_units, description_embedding_store, covariates):
    """Set up the local and global search engines."""
    logger.info("Setting up search engines")

    # Set up the local search engine
    local_context_builder = LocalSearchMixedContext(
        community_reports=reports,
        text_units=text_units,
        entities=entities,
        relationships=relationships,
        covariates=covariates,
        entity_text_embeddings=description_embedding_store,
        embedding_vectorstore_key=EntityVectorStoreKey.ID,
        text_embedder=text_embedder,
        token_encoder=token_encoder,
    )

    local_context_params = {
        "text_unit_prop": 0.5,
        "community_prop": 0.1,
        "conversation_history_max_turns": 5,
        "conversation_history_user_turns_only": True,
        "top_k_mapped_entities": 10,
        "top_k_relationships": 10,
        "include_entity_rank": True,
        "include_relationship_weight": True,
        "include_community_rank": False,
        "return_candidate_context": False,
        "embedding_vectorstore_key": EntityVectorStoreKey.ID,
        "max_tokens": 12_000,
    }

    local_llm_params = {
        "max_tokens": 2_000,
        "temperature": 0.0,
    }

    local_search_engine = LocalSearch(
        llm=llm,
        context_builder=local_context_builder,
        token_encoder=token_encoder,
        llm_params=local_llm_params,
        context_builder_params=local_context_params,
        response_type="multiple paragraphs",
    )

    # Set up the global search engine
    global_context_builder = GlobalCommunityContext(
        community_reports=reports,
        entities=entities,
        token_encoder=token_encoder,
    )

    global_context_builder_params = {
        "use_community_summary": False,
        "shuffle_data": True,
        "include_community_rank": True,
        "min_community_rank": 0,
        "community_rank_name": "rank",
        "include_community_weight": True,
        "community_weight_name": "occurrence weight",
        "normalize_community_weight": True,
        "max_tokens": 12_000,
        "context_name": "Reports",
    }

    map_llm_params = {
        "max_tokens": 1000,
        "temperature": 0.0,
        "response_format": {"type": "json_object"},
    }

    reduce_llm_params = {
        "max_tokens": 2000,
        "temperature": 0.0,
    }

    global_search_engine = GlobalSearch(
        llm=llm,
        context_builder=global_context_builder,
        token_encoder=token_encoder,
        max_data_tokens=12_000,
        map_llm_params=map_llm_params,
        reduce_llm_params=reduce_llm_params,
        allow_general_knowledge=False,
        json_mode=True,
        context_builder_params=global_context_builder_params,
        concurrent_coroutines=32,
        response_type="multiple paragraphs",
    )

    logger.info("Search engines set up")
    return local_search_engine, global_search_engine, local_context_builder, local_llm_params, local_context_params
def format_response(response):
    """Format the response, adding appropriate line breaks and paragraph separation."""
    paragraphs = re.split(r'\n{2,}', response)

    formatted_paragraphs = []
    for para in paragraphs:
        if '```' in para:
            parts = para.split('```')
            for i, part in enumerate(parts):
                if i % 2 == 1:  # this is a code block
                    parts[i] = f"\n```\n{part.strip()}\n```\n"
            para = ''.join(parts)
        else:
            para = para.replace('. ', '.\n')
        formatted_paragraphs.append(para.strip())

    return '\n\n'.join(formatted_paragraphs)
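# Illustrative example (not part of the original script): for input text
# "One sentence. Another sentence." the non-code branch above yields
# "One sentence.\nAnother sentence."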
async def tavily_search(prompt: str):
    """Search with the Tavily API."""
    try:
        client = TavilyClient(api_key=os.environ['TAVILY_API_KEY'])
        resp = client.search(prompt, search_depth="advanced")

        # Convert the Tavily response to Markdown
        markdown_response = "# Search Results\n\n"
        for result in resp.get('results', []):
            markdown_response += f"## [{result['title']}]({result['url']})\n\n"
            markdown_response += f"{result['content']}\n\n"
        return markdown_response
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Tavily search error: {str(e)}")
@asynccontextmanager
async def lifespan(app: FastAPI):
    # Runs at startup
    global local_search_engine, global_search_engine, question_generator
    try:
        logger.info("Initializing search engines and question generator...")
        llm, token_encoder, text_embedder = await setup_llm_and_embedder()
        entities, relationships, reports, text_units, description_embedding_store, covariates = await load_context()
        local_search_engine, global_search_engine, local_context_builder, local_llm_params, local_context_params = await setup_search_engines(
            llm, token_encoder, text_embedder, entities, relationships, reports,
            text_units, description_embedding_store, covariates
        )

        question_generator = LocalQuestionGen(
            llm=llm,
            context_builder=local_context_builder,
            token_encoder=token_encoder,
            llm_params=local_llm_params,
            context_builder_params=local_context_params,
        )
        logger.info("Initialization complete.")
    except Exception as e:
        logger.error(f"Error during initialization: {str(e)}")
        raise

    yield

    # Runs at shutdown
    logger.info("Shutting down...")

app = FastAPI(lifespan=lifespan)
# Add the following code for the chat_completions function
async def full_model_search(prompt: str):
    """Run a full search: local retrieval, global retrieval, and Tavily search."""
    local_result = await local_search_engine.asearch(prompt)
    global_result = await global_search_engine.asearch(prompt)
    tavily_result = await tavily_search(prompt)

    # Format the results
    formatted_result = "# 🔥🔥🔥 Combined Search Results\n\n"
    formatted_result += "## 🔥🔥🔥 Local Search Results\n"
    formatted_result += format_response(local_result.response) + "\n\n"
    formatted_result += "## 🔥🔥🔥 Global Search Results\n"
    formatted_result += format_response(global_result.response) + "\n\n"
    formatted_result += "## 🔥🔥🔥 Tavily Search Results\n"
    formatted_result += tavily_result + "\n\n"
    return formatted_result
@app.post("/v1/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
    if not local_search_engine or not global_search_engine:
        logger.error("Search engines not initialized")
        raise HTTPException(status_code=500, detail="Search engines not initialized")

    try:
        logger.info(f"Received chat completion request: {request}")
        prompt = request.messages[-1].content
        logger.info(f"Processing prompt: {prompt}")

        # Choose a search method based on the requested model
        if request.model == "graphrag-global-search:latest":
            result = await global_search_engine.asearch(prompt)
            formatted_response = format_response(result.response)
        elif request.model == "tavily-search:latest":
            result = await tavily_search(prompt)
            formatted_response = result
        elif request.model == "full-model:latest":
            formatted_response = await full_model_search(prompt)
        else:  # default to local search
            result = await local_search_engine.asearch(prompt)
            formatted_response = format_response(result.response)

        logger.info(f"Formatted search result: {formatted_response}")

        # Streaming and non-streaming responses are handled as before
        if request.stream:
            async def generate_stream():
                chunk_id = f"chatcmpl-{uuid.uuid4().hex}"
                lines = formatted_response.split('\n')
                for i, line in enumerate(lines):
                    chunk = {
                        "id": chunk_id,
                        "object": "chat.completion.chunk",
                        "created": int(time.time()),
                        "model": request.model,
                        "choices": [
                            {
                                "index": 0,
                                "delta": {"content": line + '\n'},  # if i > 0 else {"role": "assistant", "content": ""},
                                "finish_reason": None
                            }
                        ]
                    }
                    yield f"data: {json.dumps(chunk)}\n\n"
                    await asyncio.sleep(0.05)

                final_chunk = {
                    "id": chunk_id,
                    "object": "chat.completion.chunk",
                    "created": int(time.time()),
                    "model": request.model,
                    "choices": [
                        {
                            "index": 0,
                            "delta": {},
                            "finish_reason": "stop"
                        }
                    ]
                }
                yield f"data: {json.dumps(final_chunk)}\n\n"
                yield "data: [DONE]\n\n"

            return StreamingResponse(generate_stream(), media_type="text/event-stream")
        else:
            response = ChatCompletionResponse(
                model=request.model,
                choices=[
                    ChatCompletionResponseChoice(
                        index=0,
                        message=Message(role="assistant", content=formatted_response),
                        finish_reason="stop"
                    )
                ],
                usage=Usage(
                    prompt_tokens=len(prompt.split()),
                    completion_tokens=len(formatted_response.split()),
                    total_tokens=len(prompt.split()) + len(formatted_response.split())
                )
            )
            logger.info(f"Sending response: {response}")
            return JSONResponse(content=response.dict())
    except Exception as e:
        logger.error(f"Error handling chat completion: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/v1/models")
async def list_models():
    """Return the list of available models."""
    logger.info("Received models list request")
    current_time = int(time.time())
    models = [
        {"id": "graphrag-local-search:latest", "object": "model", "created": current_time - 100000, "owned_by": "graphrag"},
        {"id": "graphrag-global-search:latest", "object": "model", "created": current_time - 95000, "owned_by": "graphrag"},
        # {"id": "graphrag-question-generator:latest", "object": "model", "created": current_time - 90000, "owned_by": "graphrag"},
        # {"id": "gpt-3.5-turbo:latest", "object": "model", "created": current_time - 80000, "owned_by": "openai"},
        # {"id": "text-embedding-3-small:latest", "object": "model", "created": current_time - 70000, "owned_by": "openai"},
        {"id": "tavily-search:latest", "object": "model", "created": current_time - 85000, "owned_by": "tavily"},
        {"id": "full-model:latest", "object": "model", "created": current_time - 80000, "owned_by": "combined"}
    ]

    response = {
        "object": "list",
        "data": models
    }
    logger.info(f"Sending models list: {response}")
    return JSONResponse(content=response)
if __name__ == "__main__":
    import uvicorn

    logger.info(f"Starting server on port {PORT}")
    uvicorn.run(app, host="0.0.0.0", port=PORT)
````
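For context: the failing symbol is only used once, in `load_context`. Below is a minimal compatibility sketch, assuming the helper was removed in graphrag 1.x because the indexing pipeline now writes entity description embeddings into the vector store itself. I have not verified this against the 1.2.0 source, and `connect_description_embeddings` is my own hypothetical wrapper, not a graphrag API:

```python
# Hedged sketch covering both GraphRAG 0.3.x and 1.x. All version behavior
# described here is an assumption to be checked against your installed release.
try:
    # Present in graphrag 0.3.x; appears to be removed in later releases.
    from graphrag.query.input.loaders.dfs import store_entity_semantic_embeddings
except ImportError:
    store_entity_semantic_embeddings = None

from graphrag.vector_stores.lancedb import LanceDBVectorStore


def connect_description_embeddings(entities, lancedb_uri, collection_name):
    """Connect to the entity-description collection; populate it only on 0.3.x."""
    store = LanceDBVectorStore(collection_name=collection_name)
    store.connect(db_uri=lancedb_uri)
    if store_entity_semantic_embeddings is not None:
        # 0.3.x behavior: push the loaded entity embeddings into LanceDB here.
        store_entity_semantic_embeddings(entities=entities, vectorstore=store)
    # Assumption for 1.x: the indexer already populated the collection,
    # so connecting is enough and no store step is needed.
    return store
```

The collection name may also have changed between releases (this script uses `entity_description_embeddings`; newer pipelines appear to use names like `default-entity-description`), so that is worth checking as well.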
Steps to reproduce
No response
GraphRAG Config Used
# Paste your config here
Logs and screenshots
No response
Additional Information
- GraphRAG Version: 1.2.0
- Operating System: Windows 10
- Python Version: 3.12
- Related Issues:
I'm having the same issue too. Is this still unresolved?
I am having the same issue on version 2.1.0.