Files
Source/ollama/rag_company.py
T
cooney 30a049c5e1 랭체인 에이전트 심화
- RouterChain, RunnableBranch, SequentialChain
2026-06-05 17:58:54 +09:00

440 lines
13 KiB
Python

import gradio as gr
from langchain_community.document_loaders import PyPDFLoader, CSVLoader, TextLoader, UnstructuredWordDocumentLoader, Docx2txtLoader, UnstructuredExcelLoader
from dotenv import load_dotenv
from langchain_core.output_parsers import StrOutputParser
from langchain_ibm import WatsonxEmbeddings
from langchain_ibm import ChatWatsonx
from langchain_ollama import OllamaEmbeddings
from pathlib import Path
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_chroma import Chroma
from langchain_classic.retrievers.self_query.chroma import ChromaTranslator
from langchain_classic.retrievers.self_query.base import SelfQueryRetriever
from langchain_classic.chains.query_constructor.base import AttributeInfo
from langchain_classic.retrievers import EnsembleRetriever, ContextualCompressionRetriever, BM25Retriever
from langchain_cohere import CohereRerank
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_classic.memory import ConversationBufferWindowMemory
from langchain_classic.chains import ConversationalRetrievalChain
import os
import shutil
import pickle
# 모델(LLM, Embedding)
load_dotenv()
apikey = os.getenv("WATSONX_API_KEY")
project_id = os.getenv("WATSONX_PROJECT_ID")
watsonx_ai_url = os.getenv("WATSONX_URL")
COHERE_API_KEY = os.getenv("COHERE_API_KEY")
watson_embedding = WatsonxEmbeddings(
model_id="ibm/granite-embedding-278m-multilingual",
url = f"{watsonx_ai_url}",
api_key = f"{apikey}",
project_id=f"{project_id}"
)
watson_llm = ChatWatsonx(
model_id="ibm/granite-4-h-small",
url = f"{watsonx_ai_url}",
api_key = f"{apikey}",
project_id=f"{project_id}",
max_tokens = 2000,
params = {
"temperature":0
}
)
ollama_embedding = OllamaEmbeddings(model="nomic-embed-text-v2-moe")
splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
LOADERS = {
".pdf" : PyPDFLoader,
".csv" : CSVLoader,
".docx" : UnstructuredWordDocumentLoader,
".xlsx" : UnstructuredExcelLoader,
".txt" : TextLoader,
}
CHROMA_DIR = "./db/chroma"
COLLECTION_NAME = "job_rag"
CHUNKS_PATH = "./db/chunks.pkl"
DOCUMENTS = []
CHUNKS = []
VECTORSTORE = None
BM25_RETRIEVER = None
DENSE_RETRIEVER = None
SELFQUERY_RETRIEVER = None
FINAL_RETRIEVER = None
QA_CHAIN = None
META_FIELDS = [
AttributeInfo(name="case_id", description="채용년도", type="integer"),
AttributeInfo(name="recruitment_period", description="상반기 또는 하반기", type="string"),
AttributeInfo(name="company", description="회사명", type="string"),
AttributeInfo(name="document_type", description="직무기술서, 채용공고, 기업분석", type="string"),
AttributeInfo(name="file_name", description="파일명", type="string"),
]
BM25_RETRIEVER = None
DENSE_RETRIEVER = None
SELFQUERY_RETRIEVER = None
FINAL_RETRIEVER = None
QA_CHAIN = None
META_FILEDS = [
AttributeInfo(name="year", description="채용년도", type="integer"),
AttributeInfo(
name="recruitment_period", description="상반기 또는 하반기", type="string"
),
AttributeInfo(name="company", description="회사명", type="string"),
AttributeInfo(
name="document_type",
description="직무기술서, 채용공고, 기업분석",
type="string",
),
AttributeInfo(name="file_name", description="파일명", type="string"),
]
# 대화 메모리
memory = ConversationBufferWindowMemory(k=5, memory_key="chat_history", return_messages=True, output_key="answer", input_key="question")
SYSTEM_PROMPT = """\
당신은 회사 내부 문서를 기반으로 직원들의 질문에 답하는 AI 어시스턴트입니다.
다음 규칙을 반드시 지켜주세요.
1. 제공된 문서 내용에만 기반하여 답변하세요.
2. 문서에 없는 내용은 '해당 내용은 제공된 문서에서 찾을 수 없습니다.' 라고 답하세요.
3. 답변 마지막에 참고한 문서명을 명시하세요.
4. 한국어로 명확하고 구체적으로 답변하세요.
"""
QA_PROMPT = ChatPromptTemplate.from_messages([
("system", SYSTEM_PROMPT),
MessagesPlaceholder(variable_name="chat_history"),
("human", """
[참고문서]
{context}
[질문]
{question}"""
),
])
# =======================
# 앱 시작 시
# =======================
def build_retriever(chunks, save_chunks=False):
global BM25_RETRIEVER
global DENSE_RETRIEVER
global SELFQUERY_RETRIEVER
global FINAL_RETRIEVER
# 검색테스트 탭으로 바로 시작한다면
# BM25 index 작업을 폴더에 저장시키기
if save_chunks:
with open(CHUNKS_PATH, "wb") as f:
pickle.dump(chunks, f)
# retriever 초기화
# BM25 index 는 Chroma 에 저장되지 않음
BM25_RETRIEVER = BM25Retriever.from_documents(chunks, k=5)
# 일반검색
DENSE_RETRIEVER = VECTORSTORE.as_retriever(k=20)
# 셀프쿼리
SELFQUERY_RETRIEVER = SelfQueryRetriever.from_llm(
llm=watson_llm,
vectorstore=VECTORSTORE,
document_contents="계열사 직무기술서 문서",
metadata_field_info=META_FILEDS,
structured_query_translator=ChromaTranslator(),
search_kwargs={"k": 20},
)
# final : bm25 + 일반 + rerank
ensemble = EnsembleRetriever(
retrievers=[BM25_RETRIEVER, DENSE_RETRIEVER], weights=[0.35, 0.65]
)
reranker = CohereRerank(model="rerank-v4.0-pro", top_n=5)
FINAL_RETRIEVER = ContextualCompressionRetriever(
base_compressor=reranker, base_retriever=ensemble
)
return "Retriever 생성 완료"
def initialize():
global VECTORSTORE
# db 없는 경우
if not Path(CHUNKS_PATH).exists():
print("기존 vector 없음")
return
# BM25 제외한 retriever 는 이 부분만 하면 가능
# 기존 vectorstore 호출
VECTORSTORE = Chroma(
persist_directory=CHROMA_DIR,
collection_name=COLLECTION_NAME,
embedding_function=ollama_embedding,
)
# BM25 Retriever => 파일 로드
if Path(CHUNKS_PATH).exists():
with open(CHUNKS_PATH, "rb") as f:
chunks = pickle.load(f)
build_retriever(chunks=chunks, save_chunks=False)
print("Retriever 로드")
# ==========
# Tap 1 - 기능 구현
# ==========
def extract_metadata(file_path):
# 2026 상 삼성 E&A 직무기술서
# {year:2026, recruitment_period:상반기, company:삼성E&A, file_name:2026 상 삼성E&A 직무기술서}
# 확장자를 제외한 파일명
name = file_path.name
datas = name.split()
return {
"year": int(datas[0]),
"recruitment_period": datas[1] + "반기",
"company": datas[2],
"document_type" : datas[3],
"file_name": name
}
def upload_files(files):
"""
여러 개의 파일(pdf, csv)이 업로드 될 때 각 파일을 load() 한 결과는 DOCUMENTS 추가
몇 개의 문서가 업로드 되었는지 리턴
확장자 분리
"""
global DOCUMENTS
global CHUNKS
global VECTORSTORE
global BM25_RETRIEVER
global DENSE_RETRIEVER
global SELFQUERY_RETRIEVER
global FINAL_RETRIEVER
# 문서를 새롭게 업로드할 때 기존 내용이 있을 수도 있어 제거
BM25_RETRIEVER = None
DENSE_RETRIEVER = None
SELFQUERY_RETRIEVER = None
FINAL_RETRIEVER = None
CHUNKS = []
VECTORSTORE = []
# 💡 [Check] 파일이 업로드되지 않고 빈 상태로 버튼을 누른 경우 처리
if files is None:
return "오류: 업로드할 파일을 먼저 선택해 주세요!"
all_docs = []
for file in files:
# 파일명 가져오기
path = Path(file.name)
# 확장자 가져오기
ext = path.suffix.lower()
loader = LOADERS[ext](file.name)
docs = loader.load()
# metadata 정리
meta_info = extract_metadata(path)
# metadata 업데이트
for doc in docs:
doc.metadata.update(meta_info)
all_docs.extend(docs)
DOCUMENTS = all_docs
return f"문서 수 : {len(all_docs)}"
def preview_chunks():
global DOCUMENTS
global CHUNKS
if not DOCUMENTS:
return "문서가 없음."
# 전체문서는 DOCUMENTS 안에 있음
# 분리
CHUNKS = splitter.split_documents(DOCUMENTS)
# 청크 10개 까지만 내용 출력
preview = []
for i, chunk in enumerate(CHUNKS[:10]):
preview.append(f"""[CHUNK {i + 1}]{chunk.page_content[:100]}\n
""")
return "\n\n".join(preview)
def build_vectorstore():
global VECTORSTORE
global CHUNKS
if not CHUNKS:
return "먼저 CHUNK를 생성하세요."
# 기존의 VECTORSTORE가 있다면 제거
if Path(CHROMA_DIR).exists():
shutil.rmtree(CHROMA_DIR)
VECTORSTORE = Chroma.from_documents(documents=CHUNKS,
embedding=watson_embedding,
persist_directory=CHROMA_DIR,
collection_name=COLLECTION_NAME
)
# retriever 생성
# save_chunks = True : bm25 index 저장
build_retriever(CHUNKS, save_chunks=True)
global QA_CHAIN
QA_CHAIN = None
return f"""
생성 완료
Chunk: {len(CHUNKS)}
Vector: {VECTORSTORE._collection.count()}
"""
# ==========
# Tap 2 - 기능 구현
# 1. 임베딩 작업 완료
# 2. 문서관리 => 검색테스트
# ==========
def format_docs(docs):
"""Document 객체에서 page_content 추출"""
if not docs:
return "검색 결과 없음"
result = []
result.append(f"검색 결과 수 {len(docs)}\n")
for i, d in enumerate(docs[:3], 1):
result.append(f"""
[문서 {i}]
회사 : {d.metadata.get("company","-")}
유형 : {d.metadata.get("document_type","-")}
년도 : {d.metadata.get("year","-")} {d.metadata.get("recruitment_period","-")}
출처 : {d.metadata.get("file_name","-")}
{d.page_content[:100]}
""")
return "\n".join(result)
def search_test(query):
if FINAL_RETRIEVER is None:
return (
"BM25 retriever 미생성",
"Dense retriever 미생성",
"SelfQuery retriever 미생성",
"Final retriever 미생성",
)
# 각각의 retriever 결과 추출(Document)한 후
# format_docs() return
bm25_docs = format_docs(BM25_RETRIEVER.invoke(query))
dense_docs = format_docs(DENSE_RETRIEVER.invoke(query))
self_docs = format_docs(SELFQUERY_RETRIEVER.invoke(query))
final_docs = format_docs(FINAL_RETRIEVER.invoke(query))
return bm25_docs, dense_docs, self_docs, final_docs
# ==========
# Tab 3 - 기능 구현
# ChatInterface
# - history : 대화이력관리
# RunnableWithMessageHistory
# ==========
def create_chain():
global QA_CHAIN
if QA_CHAIN is None:
QA_CHAIN = ConversationalRetrievalChain.from_llm(
llm = watson_llm,
retriever = FINAL_RETRIEVER,
memory = memory,
combine_docs_chain_kwargs = {"prompt": QA_PROMPT},
get_chat_history=lambda h: h,
return_source_documents = True,
)
return QA_CHAIN
def chat(message, history):
global QA_CHAIN
if FINAL_RETRIEVER is None:
return "먼저 vector DB를 생성하세요"
QA_CHAIN = create_chain()
response = QA_CHAIN.invoke({"question": message})
answer = response["answer"]
sources = []
for doc in response['source_documents']:
sources.append(
f"{doc.metadata.get('company', '-')} - "
f"{doc.metadata.get('file_name', '-')}"
)
answer += "\n\n[참고문서]\n"
answer += "\n".join(list(set(sources)))
return answer
# ==========
# Gradio UI
# ==========
with gr.Blocks() as app:
gr.Markdown("# 사내 문서 RAG")
with gr.Tab("문서관리"):
files = gr.File(file_count = "multiple")
upload_btn = gr.Button("문서 업로드")
upload_status = gr.Textbox()
upload_btn.click(upload_files, files, upload_status)
chunk_btn = gr.Button("chunk 확인")
chunk_preview = gr.Textbox(lines = 20)
chunk_btn.click(preview_chunks, outputs = chunk_preview)
vector_btn = gr.Button("vector DB 생성")
vector_status = gr.Textbox()
vector_btn.click(build_vectorstore, outputs = vector_status)
with gr.Tab("검색 테스트"):
query = gr.Textbox(label = "검색어")
search_btn = gr.Button("검색")
bm25_box = gr.Textbox(label="BM25") # 키워드
dense_box = gr.Textbox(label = "Dense") # 일반 검색
self_box = gr.Textbox(label = "Self") # selfquery
rerank_box = gr.Textbox(label = "Final")
search_btn.click(search_test, query, outputs=[bm25_box, dense_box, self_box, rerank_box])
with gr.Tab("RAG 채팅"):
gr.ChatInterface(chat)
if __name__ =="__main__":
initialize()
app.launch()