Learning RAG from Scratch: RAG-Fusion
Original article: https://zhuanlan.zhihu.com/p/684994205
Type: technical share
This article is an original contribution by @lucas大叔, reposted here. If there is any infringement, please let us know and it will be removed.
Basic Principles
RAG-Fusion can be thought of as an evolved version of MultiQueryRetriever. As shown in the figure below, RAG-Fusion first generates several new versions of the original question from different angles to improve question quality, then runs a vector search for each question; up to this point it matches what MultiQueryRetriever does. The difference is that RAG-Fusion adds a ranking step before the retrieved content is fed to the LLM to generate the answer.
[Figure: the RAG-Fusion pipeline — multi-query generation, per-query retrieval, RRF fusion, answer generation]
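For reference, the MultiQueryRetriever baseline that RAG-Fusion builds on can be set up in a few lines with LangChain. This is a sketch only, assuming the vectorstore and llm objects constructed later in this post:

from langchain.retrievers.multi_query import MultiQueryRetriever

# Baseline for comparison: MultiQueryRetriever also generates query variants and
# retrieves for each one, but merges the results without any RRF re-ranking step.
multi_query_retriever = MultiQueryRetriever.from_llm(
    retriever=vectorstore.as_retriever(),  # assumes the Chroma vectorstore built below
    llm=llm,                               # assumes the HuggingFacePipeline LLM built below
)
docs = multi_query_retriever.get_relevant_documents("公司首次公開發(fā)行新股,應(yīng)當(dāng)符合哪些條件?")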
Ranking involves two steps. First, the content returned for each question is sorted independently by similarity, which fixes each returned chunk's position within its own candidate set; the more similar the chunk, the higher it ranks. Second, the content returned across all questions is fused into a single overall ranking using RRF (Reciprocal Rank Fusion), whose principle is illustrated in the figure below.
[Figure: how Reciprocal Rank Fusion combines per-query rankings into one fused ranking]
The RRF score is computed by a very simple formula:

RRF_score(d) = Σ_q 1 / (k + rank_q(d))

where rank_q(d) is the position of document d, sorted by distance, within the result set for query q, and k is a constant smoothing factor, usually k = 60. RRF aggregates the results of the different retrievals into a single unified score for each chunk.
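To make the formula concrete, here is a tiny self-contained example with toy document IDs (invented for illustration) that fuses two ranked lists the same way the reciprocal_rank_fusion() function below does. Note that enumerate() yields 0-based ranks, a constant offset from the 1-based convention that does not change the ordering:

ranks = [
    ["doc_a", "doc_b", "doc_c"],   # ranking returned for query 1
    ["doc_b", "doc_d", "doc_a"],   # ranking returned for query 2
]
k = 60
fused = {}
for ranking in ranks:
    for rank, doc in enumerate(ranking):          # rank is 0-based here
        fused[doc] = fused.get(doc, 0) + 1 / (rank + k)
for doc, score in sorted(fused.items(), key=lambda x: x[1], reverse=True):
    print(doc, round(score, 5))
# doc_b (1/61 + 1/60 ≈ 0.03306) edges out doc_a (1/60 + 1/62 ≈ 0.03280):
# appearing near the top of several lists is rewarded.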
Code Walkthrough
First, import the necessary packages:
import torch
from langchain.load import dumps, loads
from langchain.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain import PromptTemplate
from langchain import HuggingFacePipeline
from transformers import AutoTokenizer, pipeline, AutoModelForCausalLM
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.chains.llm import LLMChain
from langchain_core.output_parsers import StrOutputParser

Next, do some preparation: write functions that load the embedding model, the LLM, and the document.
def load_embedding(embed_path):
    embeddings = HuggingFaceEmbeddings(
        model_name=embed_path,
        model_kwargs={"device": "cuda"},
        encode_kwargs={"normalize_embeddings": True},
    )
    return embeddings
def load_llm(model_path):
    # device_map and torch_dtype are model-loading arguments, not tokenizer arguments
    tokenizer = AutoTokenizer.from_pretrained(
        model_path,
        trust_remote_code=True,
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.float16,
        trust_remote_code=True,
        device_map="auto",
    )
    model_pipeline = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        return_full_text=True,
    )
    llm = HuggingFacePipeline(pipeline=model_pipeline)
    return llm
def load_data(data_file):
    loader = PyPDFLoader(data_file)
    documents = loader.load_and_split()
    text_splitter = RecursiveCharacterTextSplitter(separators=["。"], chunk_size=512, chunk_overlap=32)
    texts_chunks = text_splitter.split_documents(documents)
    return texts_chunks

Write the retrieval_and_rank() function. For each query it searches the vector store with similarity_search_with_score(), which returns candidate chunks and their scores. The results for each query are then sorted by score (for Chroma the score is a distance, so ascending order puts the most similar chunks first), producing a list[list] candidate answer set.
def retrieval_and_rank(queries):
    all_results = {}
    for query in queries:
        if query:
            search_results = vectorstore.similarity_search_with_score(query)
            results = []
            for res in search_results:
                content = res[0].page_content
                score = res[1]
                results.append((content, score))
            all_results[query] = results
    document_ranks = []
    for query, doc_score_list in all_results.items():
        # Chroma's score is a distance (lower = more similar), so sort ascending
        ranking_list = [doc for doc, _ in sorted(doc_score_list, key=lambda x: x[1])]
        document_ranks.append(ranking_list)
    return document_ranks

Write the reciprocal_rank_fusion() function, which computes a fused score for every candidate chunk using the RRF formula above.
def reciprocal_rank_fusion(document_ranks, k=60):
    fused_scores = {}
    for docs in document_ranks:
        for rank, doc in enumerate(docs):
            doc_str = dumps(doc)  # serialize the chunk so it can be used as a dict key
            if doc_str not in fused_scores:
                fused_scores[doc_str] = 0
            fused_scores[doc_str] += 1 / (rank + k)
    reranked_results = [
        (loads(doc), score)
        for doc, score in sorted(fused_scores.items(), key=lambda x: x[1], reverse=True)
    ]
    return reranked_results

Finally, write the main block. Specify the model and file paths, then load the embedding model, the LLM, and the document in turn, and create the vector store vectorstore used for the retrieval that follows.
Using a prompt template, have the LLM generate 3 questions related to the query. The typical implementation passes just the generated questions to retrieval_and_rank(); you can also throw in the original query alongside them, and whether doing so helps should be validated against your actual scenario. The per-query rankings are then scored and re-ranked as a whole with the RRF algorithm.
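If you do want to include the original query, one minimal way to do it (written against the query and queries variables used in the main block below) is:

queries = [query] + [q for q in queries if q.strip()]  # prepend the original question, dropping empty lines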
The rest is routine: use the RRF-reranked content as context and call LLMChain to check the quality of the generated answer.
if __name__ == "__main__":
    data_file = "../data/中華人民共和國證券法(2019修訂).pdf"
    model_path = "/data/models/Baichuan2-13B-Chat"
    embed_path = "/data/models/bge-large-zh-v1.5"
    embeddings = load_embedding(embed_path)
    llm = load_llm(model_path)
    docs = load_data(data_file)
    vectorstore = Chroma.from_documents(docs, embeddings)
    template = """You are a helpful assistant that generates multiple search queries based on a single input query. \n
Generate multiple search queries related to: {question} \n
Output (3 queries):"""
    prompt_rag_fusion = PromptTemplate.from_template(template)
    generate_query_chain = (
        prompt_rag_fusion
        | llm
        | StrOutputParser()
        | (lambda x: x.split("\n"))
    )
    query = "公司首次公開發(fā)行新股,應(yīng)當(dāng)符合哪些條件?"
    queries = generate_query_chain.invoke({"question": query})
    all_results = retrieval_and_rank(queries)
    reranked_results = reciprocal_rank_fusion(all_results)
    # ----------------- Build the prompt template ----------------- #
    template = """你是一名智能助手,根據(jù)上下文回答用戶的問題,不需要回答額外的信息或捏造事實(shí)。
已知內(nèi)容:
{context}
問題:
{question}
"""
    prompt = PromptTemplate(template=template, input_variables=["context", "question"])
    # ----------------- Check the answer ----------------- #
    chain = LLMChain(llm=llm, prompt=prompt)
    # reranked_results is a list of (chunk, RRF score) tuples; keep only the chunk text for the context
    context = "\n".join(doc for doc, _ in reranked_results)
    result = chain.run(context=context, question=query)
    print(result)

Closing Thoughts
Although RAG-Fusion can improve query quality by generating multiple related queries, in my view it helps most for short queries that lack context or key elements; for queries that are already well formulated it adds little. Worse, the rewritten queries may retrieve information from other angles, padding the answer with irrelevant content and making it verbose, or even leading to a wrong answer. In that case, including the original query can improve the correctness of the generated answer to some extent.
New RAG tricks and techniques keep appearing, and each has scenarios it suits; choose flexibly based on your actual application, and don't be afraid to hack them aggressively!
