"""PDF RAG demo: index an uploaded PDF in Chroma and answer questions via LangGraph.

The Gradio UI has two steps:
1. Upload a PDF; it is split into chunks and stored in a Chroma collection.
2. Ask a question; a two-node LangGraph (retrieve -> generate) answers it
   from the stored chunks.
"""

import os
from typing import List, Literal, TypedDict

import gradio as gr
from dotenv import load_dotenv
from langchain_chroma import Chroma
from langchain_community.document_loaders import (
    CSVLoader,
    DirectoryLoader,
    PyPDFLoader,
    WebBaseLoader,
)
from langchain_community.vectorstores import FAISS
from langchain_core.chat_history import (
    BaseChatMessageHistory,
    InMemoryChatMessageHistory,
)
from langchain_core.documents import Document
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langchain_core.output_parsers import (
    JsonOutputParser,
    PydanticOutputParser,
    StrOutputParser,
)
from langchain_core.prompts import (
    ChatPromptTemplate,
    MessagesPlaceholder,
    PromptTemplate,
)
from langchain_core.runnables import (
    RunnableLambda,
    RunnableParallel,
    RunnablePassthrough,
)
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_ibm import ChatWatsonx, WatsonxEmbeddings
from langchain_ollama import OllamaEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langgraph.graph import END, START, StateGraph
from pydantic import BaseModel, Field
from youtube_transcript_api import YouTubeTranscriptApi

# ---------------------------------------------------------------------------
# Models (LLM, embeddings)
# ---------------------------------------------------------------------------
load_dotenv()

apikey = os.getenv("WATSONX_API_KEY")
project_id = os.getenv("WATSONX_PROJECT_ID")
watsonx_ai_url = os.getenv("WATSONX_URL")

# Shared vector-store settings, used by both indexing and retrieval.
COLLECTION_NAME = "docs"
PERSIST_DIRECTORY = "./db/chroma_db"

# Chunking and retrieval parameters.
CHUNK_SIZE = 300
CHUNK_OVERLAP = 30
TOP_K = 3

# Prompt sent to the LLM; {context} and {query} are filled in by generate().
PROMPT_TEMPLATE = """\
다음 컨텍스트를 참고하여 질문에 답하세요.
컨텍스트에 없는 내용은 모른다고 답하세요.

컨텍스트:
{context}

질문: {query}
"""

ollama_embedding = OllamaEmbeddings(model="nomic-embed-text-v2-moe")

# Pass the environment values directly: wrapping them in an f-string would
# turn an unset variable (None) into the literal string "None".
watson_embedding = WatsonxEmbeddings(
    model_id="ibm/granite-embedding-278m-multilingual",
    url=watsonx_ai_url,
    api_key=apikey,
    project_id=project_id,
)

watson_llm = ChatWatsonx(
    model_id="ibm/granite-4-h-small",
    url=watsonx_ai_url,
    api_key=apikey,
    project_id=project_id,
    max_tokens=2000,
    params={"temperature": 0},
)

# To open an already existing database instead:
# vectorstore = Chroma(collection_name="research", embedding_function=watson_embedding, persist_directory="./db/chroma_db")


def _open_vectorstore() -> Chroma:
    """Return a Chroma handle on the shared persistent collection."""
    return Chroma(
        collection_name=COLLECTION_NAME,
        embedding_function=watson_embedding,
        persist_directory=PERSIST_DIRECTORY,
    )


# ---------------------------------------------------------------------------
# 1. State definition
# ---------------------------------------------------------------------------
class RagState(TypedDict):
    """State passed between the graph nodes."""

    query: str  # user question
    retrieved_docs: list[Document]  # documents returned by the retrieve node
    answer: str  # text produced by the generate node


# ---------------------------------------------------------------------------
# Graph nodes
# ---------------------------------------------------------------------------
def retrieve(state: RagState) -> dict:
    """Look up the documents most similar to ``state['query']``.

    Returns:
        A partial state update ``{"retrieved_docs": [...]}`` holding up to
        ``TOP_K`` documents.
    """
    docs = _open_vectorstore().similarity_search(state["query"], k=TOP_K)
    return {"retrieved_docs": docs}


def generate(state: RagState) -> dict:
    """Answer ``state['query']`` using the retrieved documents as context.

    Returns:
        A partial state update ``{"answer": <LLM response text>}``.
    """
    context = "\n\n".join(doc.page_content for doc in state["retrieved_docs"])
    response = watson_llm.invoke(
        PROMPT_TEMPLATE.format(context=context, query=state["query"])
    )
    return {"answer": response.content}


# ---------------------------------------------------------------------------
# Graph wiring: START -> retrieve -> generate -> END
# ---------------------------------------------------------------------------
graph = StateGraph(RagState)
graph.add_node("retrieve", retrieve)
graph.add_node("generate", generate)
graph.add_edge(START, "retrieve")
graph.add_edge("retrieve", "generate")
graph.add_edge("generate", END)
app = graph.compile()


# ---------------------------------------------------------------------------
# UI callbacks
# ---------------------------------------------------------------------------
def process_pdf(pdf_file) -> str:
    """Step 1: load a PDF, split it into chunks and store them in Chroma.

    Args:
        pdf_file: Value delivered by the ``gr.File`` component: ``None`` when
            nothing was uploaded, otherwise a filesystem path or an object
            exposing the path as ``.name``.

    Returns:
        A status message for the UI; on success it reports the chunk count.
    """
    if pdf_file is None:
        return "PDF 파일을 업로드 해주세요."

    # Accept both a plain path and a temp-file wrapper that carries the path
    # in ``.name``.
    if isinstance(pdf_file, (str, os.PathLike)):
        pdf_path = pdf_file
    else:
        pdf_path = pdf_file.name

    docs = PyPDFLoader(pdf_path).load()
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP
    )
    chunks = splitter.split_documents(docs)
    total_chunks = len(chunks)

    # Bail out before touching the existing collection: a PDF without any
    # extractable text must not wipe the previously indexed data.
    if not chunks:
        return "PDF에서 텍스트를 추출하지 못했습니다."

    # Drop the previous collection so answers only draw on the new PDF.
    _open_vectorstore().delete_collection()

    # Build a fresh collection from the new chunks.
    Chroma.from_documents(
        chunks,
        watson_embedding,
        collection_name=COLLECTION_NAME,
        persist_directory=PERSIST_DIRECTORY,
    )

    # The number reported is the chunk count, so label it as such.
    return f"총 청크 수 : {total_chunks}"


def rag_chat(query: str) -> str:
    """Step 2: run the RAG graph for ``query`` and return the answer text."""
    # Skip the embedding/LLM round trip when there is nothing to ask.
    if not query or not query.strip():
        return "질문을 입력해주세요."

    result = app.invoke({"query": query})
    return result["answer"]


# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown("# PDF RAG 학습 앱")

    with gr.Tabs():
        with gr.Tab("LCEL RAG -> LangGraph RAG 변환"):
            # File upload component
            pdf_input = gr.File(label="PDF 업로드", file_types=[".pdf"])
            btn1 = gr.Button("분석 시작")

            # Status textbox for the indexing step
            output = gr.Textbox(label="처리결과")

            btn1.click(
                fn=process_pdf,
                inputs=[pdf_input],
                outputs=[output],
            )

            question_input = gr.Textbox(label="질문 입력")
            run_btn1 = gr.Button("질문하기")
            answer_output = gr.Textbox(label="최종 답변", lines=10)

            run_btn1.click(
                fn=rag_chat,
                inputs=[question_input],
                outputs=[answer_output],
            )

if __name__ == "__main__":
    demo.launch()