랭체인 에이전트
- 랭체인 사내문서 RAG 마무리 - 랭체인 Agent 예제
This commit is contained in:
+280
-7
@@ -1,23 +1,35 @@
|
||||
import gradio as gr
|
||||
from langchain_community.document_loaders import PyPDFLoader, CSVLoader, TextLoader, UnstructuredWordDocumentLoader, \
|
||||
Docx2txtLoader, UnstructuredExcelLoader
|
||||
from langchain_community.document_loaders import PyPDFLoader, CSVLoader, TextLoader, UnstructuredWordDocumentLoader, Docx2txtLoader, UnstructuredExcelLoader
|
||||
from dotenv import load_dotenv
|
||||
from langchain_core.output_parsers import StrOutputParser
|
||||
from langchain_ibm import WatsonxEmbeddings
|
||||
from langchain_ibm import ChatWatsonx
|
||||
from langchain_ollama import OllamaEmbeddings
|
||||
from pathlib import Path
|
||||
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
||||
from langchain_chroma import Chroma
|
||||
from langchain_classic.retrievers.self_query.chroma import ChromaTranslator
|
||||
from langchain_classic.retrievers.self_query.base import SelfQueryRetriever
|
||||
from langchain_classic.chains.query_constructor.base import AttributeInfo
|
||||
from langchain_classic.retrievers import EnsembleRetriever, ContextualCompressionRetriever, BM25Retriever
|
||||
from langchain_cohere import CohereRerank
|
||||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
from langchain_classic.memory import ConversationBufferWindowMemory
|
||||
from langchain_classic.chains import ConversationalRetrievalChain
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import pickle
|
||||
|
||||
# TODO : 일단 되게 함 공유되면 코드 비교해서 다른 부분 체크
|
||||
|
||||
# 모델(LLM, Embeddding)
|
||||
# 모델(LLM, Embedding)
|
||||
load_dotenv()
|
||||
|
||||
apikey = os.getenv("WATSONX_API_KEY")
|
||||
project_id = os.getenv("WATSONX_PROJECT_ID")
|
||||
watsonx_ai_url = os.getenv("WATSONX_URL")
|
||||
COHERE_API_KEY = os.getenv("COHERE_API_KEY")
|
||||
|
||||
watson_embedding = WatsonxEmbeddings(
|
||||
model_id="ibm/granite-embedding-278m-multilingual",
|
||||
@@ -26,6 +38,17 @@ watson_embedding = WatsonxEmbeddings(
|
||||
project_id=f"{project_id}"
|
||||
)
|
||||
|
||||
watson_llm = ChatWatsonx(
|
||||
model_id="ibm/granite-4-h-small",
|
||||
url = f"{watsonx_ai_url}",
|
||||
api_key = f"{apikey}",
|
||||
project_id=f"{project_id}",
|
||||
max_tokens = 2000,
|
||||
params = {
|
||||
"temperature":0
|
||||
}
|
||||
)
|
||||
|
||||
ollama_embedding = OllamaEmbeddings(model="nomic-embed-text-v2-moe")
|
||||
|
||||
splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
|
||||
@@ -46,6 +69,134 @@ DOCUMENTS = []
|
||||
CHUNKS = []
|
||||
VECTORSTORE = None
|
||||
|
||||
BM25_RETRIEVER = None
|
||||
DENSE_RETRIEVER = None
|
||||
SELFQUERY_RETRIEVER = None
|
||||
FINAL_RETRIEVER = None
|
||||
|
||||
QA_CHAIN = None
|
||||
|
||||
META_FIELDS = [
|
||||
AttributeInfo(name="case_id", description="채용년도", type="integer"),
|
||||
AttributeInfo(name="recruitment_period", description="상반기 또는 하반기", type="string"),
|
||||
AttributeInfo(name="company", description="회사명", type="string"),
|
||||
AttributeInfo(name="document_type", description="직무기술서, 채용공고, 기업분석", type="string"),
|
||||
AttributeInfo(name="file_name", description="파일명", type="string"),
|
||||
]
|
||||
BM25_RETRIEVER = None
|
||||
DENSE_RETRIEVER = None
|
||||
SELFQUERY_RETRIEVER = None
|
||||
FINAL_RETRIEVER = None
|
||||
|
||||
QA_CHAIN = None
|
||||
|
||||
META_FILEDS = [
|
||||
AttributeInfo(name="year", description="채용년도", type="integer"),
|
||||
AttributeInfo(
|
||||
name="recruitment_period", description="상반기 또는 하반기", type="string"
|
||||
),
|
||||
AttributeInfo(name="company", description="회사명", type="string"),
|
||||
AttributeInfo(
|
||||
name="document_type",
|
||||
description="직무기술서, 채용공고, 기업분석",
|
||||
type="string",
|
||||
),
|
||||
AttributeInfo(name="file_name", description="파일명", type="string"),
|
||||
]
|
||||
|
||||
# 대화 메모리
|
||||
memory = ConversationBufferWindowMemory(k=5, memory_key="chat_history", return_messages=True, output_key="answer", input_key="question") # TODO : input_key="question" 이것도 추가했음
|
||||
SYSTEM_PROMPT = """\
|
||||
당신은 회사 내부 문서를 기반으로 직원들의 질문에 답하는 AI 어시스턴트입니다.
|
||||
|
||||
다음 규칙을 반드시 지켜주세요.
|
||||
|
||||
1. 제공된 문서 내용에만 기반하여 답변하세요.
|
||||
2. 문서에 없는 내용은 '해당 내용은 제공된 문서에서 찾을 수 없습니다.' 라고 답하세요.
|
||||
3. 답변 마지막에 참고한 문서명을 명시하세요.
|
||||
4. 한국어로 명확하고 구체적으로 답변하세요.
|
||||
"""
|
||||
QA_PROMPT = ChatPromptTemplate.from_messages([
|
||||
("system", SYSTEM_PROMPT),
|
||||
MessagesPlaceholder(variable_name="chat_history"),
|
||||
("human", """
|
||||
[참고문서]
|
||||
{context}
|
||||
[질문]
|
||||
{question}"""
|
||||
), # TODO : 이것도 원래 query 였는데 question으로 변경했음
|
||||
])
|
||||
|
||||
# =======================
|
||||
# 앱 시작 시
|
||||
# =======================
|
||||
def build_retriever(chunks, save_chunks=False):
|
||||
|
||||
global BM25_RETRIEVER
|
||||
global DENSE_RETRIEVER
|
||||
global SELFQUERY_RETRIEVER
|
||||
global FINAL_RETRIEVER
|
||||
|
||||
# 검색테스트 탭으로 바로 시작한다면
|
||||
|
||||
# BM25 index 작업을 폴더에 저장시키기
|
||||
if save_chunks:
|
||||
with open(CHUNKS_PATH, "wb") as f:
|
||||
pickle.dump(chunks, f)
|
||||
|
||||
# retriever 초기화
|
||||
# BM25 index 는 Chroma 에 저장되지 않음
|
||||
BM25_RETRIEVER = BM25Retriever.from_documents(chunks, k=5)
|
||||
|
||||
# 일반검색
|
||||
DENSE_RETRIEVER = VECTORSTORE.as_retriever(k=20)
|
||||
|
||||
# 셀프쿼리
|
||||
SELFQUERY_RETRIEVER = SelfQueryRetriever.from_llm(
|
||||
llm=watson_llm,
|
||||
vectorstore=VECTORSTORE,
|
||||
document_contents="계열사 직무기술서 문서",
|
||||
metadata_field_info=META_FILEDS,
|
||||
structured_query_translator=ChromaTranslator(),
|
||||
search_kwargs={"k": 20},
|
||||
)
|
||||
|
||||
# final : bm25 + 일반 + rerank
|
||||
ensemble = EnsembleRetriever(
|
||||
retrievers=[BM25_RETRIEVER, DENSE_RETRIEVER], weights=[0.35, 0.65]
|
||||
)
|
||||
reranker = CohereRerank(model="rerank-v4.0-pro", top_n=5)
|
||||
FINAL_RETRIEVER = ContextualCompressionRetriever(
|
||||
base_compressor=reranker, base_retriever=ensemble
|
||||
)
|
||||
|
||||
return "Retriever 생성 완료"
|
||||
|
||||
|
||||
def initialize():
|
||||
global VECTORSTORE
|
||||
|
||||
# db 없는 경우
|
||||
if not Path(CHUNKS_PATH).exists():
|
||||
print("기존 vector 없음")
|
||||
return
|
||||
|
||||
# BM25 제외한 retriever 는 이 부분만 하면 가능
|
||||
# 기존 vectorstore 호출
|
||||
VECTORSTORE = Chroma(
|
||||
persist_directory=CHROMA_DIR,
|
||||
collection_name=COLLECTION_NAME,
|
||||
embedding_function=ollama_embedding,
|
||||
)
|
||||
|
||||
# BM25 Retriever => 파일 로드
|
||||
if Path(CHUNKS_PATH).exists():
|
||||
with open(CHUNKS_PATH, "rb") as f:
|
||||
chunks = pickle.load(f)
|
||||
|
||||
build_retriever(chunks=chunks, save_chunks=False)
|
||||
print("Retriever 로드")
|
||||
|
||||
# ==========
|
||||
# Tap 1 - 기능 구현
|
||||
# ==========
|
||||
@@ -61,6 +212,7 @@ def extract_metadata(file_path):
|
||||
"year": int(datas[0]),
|
||||
"recruitment_period": datas[1] + "반기",
|
||||
"company": datas[2],
|
||||
"document_type" : datas[3],
|
||||
"file_name": name
|
||||
}
|
||||
|
||||
@@ -72,7 +224,24 @@ def upload_files(files):
|
||||
확장자 분리
|
||||
"""
|
||||
global DOCUMENTS
|
||||
global CHUNKS
|
||||
global VECTORSTORE
|
||||
global BM25_RETRIEVER
|
||||
global DENSE_RETRIEVER
|
||||
global SELFQUERY_RETRIEVER
|
||||
global FINAL_RETRIEVER
|
||||
|
||||
# 문서를 새롭게 업로드할 때 기존 내용이 있을 수도 있어 제거
|
||||
BM25_RETRIEVER = None
|
||||
DENSE_RETRIEVER = None
|
||||
SELFQUERY_RETRIEVER = None
|
||||
FINAL_RETRIEVER = None
|
||||
CHUNKS = []
|
||||
VECTORSTORE = None # TODO : 원래 []였는데 None로 바꿈
|
||||
|
||||
# 💡 [Check] 파일이 업로드되지 않고 빈 상태로 버튼을 누른 경우 처리
|
||||
if files is None:
|
||||
return "오류: 업로드할 파일을 먼저 선택해 주세요!"
|
||||
all_docs = []
|
||||
|
||||
for file in files:
|
||||
@@ -131,6 +300,13 @@ def build_vectorstore():
|
||||
collection_name=COLLECTION_NAME
|
||||
)
|
||||
|
||||
# retriever 생성
|
||||
# save_chunks = True : bm25 index 저장
|
||||
build_retriever(CHUNKS, save_chunks=True)
|
||||
|
||||
global QA_CHAIN
|
||||
QA_CHAIN = None
|
||||
|
||||
return f"""
|
||||
생성 완료
|
||||
|
||||
@@ -140,9 +316,99 @@ Vector: {VECTORSTORE._collection.count()}
|
||||
"""
|
||||
|
||||
# ==========
|
||||
# Gradio UI
|
||||
# Tap 2 - 기능 구현
|
||||
# 1. 임베딩 작업 완료
|
||||
# 2. 문서관리 => 검색테스트
|
||||
# ==========
|
||||
def format_docs(docs):
|
||||
"""Document 객체에서 page_content 추출"""
|
||||
|
||||
if not docs:
|
||||
return "검색 결과 없음"
|
||||
|
||||
result = []
|
||||
result.append(f"검색 결과 수 {len(docs)}건\n")
|
||||
for i, d in enumerate(docs[:3], 1):
|
||||
result.append(f"""
|
||||
[문서 {i}]
|
||||
|
||||
회사 : {d.metadata.get("company","-")}
|
||||
유형 : {d.metadata.get("document_type","-")}
|
||||
년도 : {d.metadata.get("year","-")} {d.metadata.get("recruitment_period","-")}
|
||||
출처 : {d.metadata.get("file_name","-")}
|
||||
|
||||
{d.page_content[:100]}
|
||||
""")
|
||||
|
||||
return "\n".join(result)
|
||||
|
||||
|
||||
def search_test(query):
|
||||
if FINAL_RETRIEVER is None:
|
||||
return (
|
||||
"BM25 retriever 미생성",
|
||||
"Dense retriever 미생성",
|
||||
"SelfQuery retriever 미생성",
|
||||
"Final retriever 미생성",
|
||||
)
|
||||
|
||||
# 각각의 retriever 결과 추출(Document)한 후
|
||||
# format_docs() return
|
||||
bm25_docs = format_docs(BM25_RETRIEVER.invoke(query))
|
||||
dense_docs = format_docs(DENSE_RETRIEVER.invoke(query))
|
||||
self_docs = format_docs(SELFQUERY_RETRIEVER.invoke(query))
|
||||
final_docs = format_docs(FINAL_RETRIEVER.invoke(query))
|
||||
|
||||
return bm25_docs, dense_docs, self_docs, final_docs
|
||||
|
||||
# ==========
|
||||
# Tab 3 - 기능 구현
|
||||
# ChatInterface
|
||||
# - history : 대화이력관리
|
||||
# RunnableWithMessageHistory
|
||||
# ==========
|
||||
|
||||
def create_chain():
|
||||
global QA_CHAIN
|
||||
|
||||
if QA_CHAIN is None:
|
||||
QA_CHAIN = ConversationalRetrievalChain.from_llm(
|
||||
llm = watson_llm,
|
||||
retriever = FINAL_RETRIEVER,
|
||||
memory = memory,
|
||||
combine_docs_chain_kwargs = {"prompt": QA_PROMPT},
|
||||
get_chat_history=lambda h: h, # TODO : 이거 맞음?
|
||||
return_source_documents = True,
|
||||
)
|
||||
|
||||
return QA_CHAIN
|
||||
|
||||
def chat(message, history):
|
||||
global QA_CHAIN
|
||||
|
||||
if FINAL_RETRIEVER is None:
|
||||
return "먼저 vector DB를 생성하세요"
|
||||
|
||||
QA_CHAIN = create_chain()
|
||||
response = QA_CHAIN.invoke({"question": message})
|
||||
|
||||
answer = response["answer"]
|
||||
|
||||
sources = []
|
||||
|
||||
for doc in response['source_documents']:
|
||||
sources.append(
|
||||
f"{doc.metadata.get('company', '-')} - "
|
||||
f"{doc.metadata.get('file_name', '-')}"
|
||||
)
|
||||
|
||||
answer += "\n\n[참고문서]\n"
|
||||
answer += "\n".join(list(set(sources)))
|
||||
return answer
|
||||
|
||||
# ==========
|
||||
# Gradio UI
|
||||
# ==========
|
||||
with gr.Blocks() as app:
|
||||
gr.Markdown("# 사내 문서 RAG")
|
||||
with gr.Tab("문서관리"):
|
||||
@@ -158,12 +424,19 @@ with gr.Blocks() as app:
|
||||
vector_btn.click(build_vectorstore, outputs = vector_status)
|
||||
|
||||
with gr.Tab("검색 테스트"):
|
||||
pass
|
||||
query = gr.Textbox(label = "검색어")
|
||||
search_btn = gr.Button("검색")
|
||||
bm25_box = gr.Textbox(label="BM25") # 키워드
|
||||
dense_box = gr.Textbox(label = "Dense") # 일반 검색
|
||||
self_box = gr.Textbox(label = "Self") # selfquery
|
||||
rerank_box = gr.Textbox(label = "Final")
|
||||
|
||||
search_btn.click(search_test, query, outputs=[bm25_box, dense_box, self_box, rerank_box])
|
||||
|
||||
with gr.Tab("RAG 채팅"):
|
||||
pass
|
||||
gr.ChatInterface(chat)
|
||||
|
||||
pass
|
||||
|
||||
if __name__ =="__main__":
|
||||
initialize()
|
||||
app.launch()
|
||||
Reference in New Issue
Block a user