Graph Expansion
Enhances any base retriever (e.g., BM25) by discovering multi-hop connections in a single step
Gist Memory
Maintains important context across retrieval steps, simulating human working memory
GeAR starts by employing a base retriever (e.g., BM25) to fetch passages. It then applies a novel graph expansion technique, which we call SyncGE, to discover multi-hop content. During graph expansion, GeAR first uses a language model to extract the important information from the retrieved passages. We represent this information as triples and use those triples as entry points for exploring the graph, which helps bridge passages that may be several reasoning steps (or hops) apart. With the newly discovered passages, we use a language model to maintain a "gist memory" that accumulates key information across multiple retrieval steps, similar to how the human brain's hippocampus stores important memories. We then reason over this gist memory to determine whether additional information is needed. If it is, we formulate the next query to search for; if it is not, we merge the information in the gist memory with the other retrieved passages and return the combination of both.
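To make this concrete, the sketch below walks through a hypothetical two-hop question in terms of the triples GeAR would accumulate. The question, the triples, and the variable names are illustrative only and are not produced by the actual system; the real components appear in the pseudocode further down.

# Illustrative two-hop example (hypothetical, not real system output)
query = "In which city was the director of Inception born?"

# Step 1: the base retriever surfaces a passage about the film, and the reader
# LLM distils it into a proximal triple that enters the gist memory.
gist = [("Inception", "directed by", "Christopher Nolan")]

# Graph expansion follows edges from the linked triple and reaches a passage
# one hop away, contributing a second triple to the gist memory.
gist.append(("Christopher Nolan", "born in", "London"))

# Reasoning over the gist memory now finds the query answerable, so GeAR merges
# the gist-linked passages with the retrieved passages and returns both.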
State-of-the-art Retrieval and QA Performance across challenging multi-hop datasets
Solves multi-hop problems using fewer tokens than previous approaches
Below we present pseudocode for diverse triple beam search (used during graph expansion) and for the full GeAR pipeline. Please note that, although the code is written in a Python-esque style, it is still pseudocode.
class GistMemory:
    """
    Manages the accumulation and storage of proximal triples
    across multiple retrieval steps.
    """
    def __init__(self):
        self.proximal_triples = []

    def add_triples(self, new_triples):
        """Append new proximal triples to memory"""
        self.proximal_triples.extend(new_triples)

    def get_all_triples(self):
        """Return all accumulated triples"""
        return self.proximal_triples
def GeAR(query, max_steps=4):
    """
    GeAR pipeline implementation with multi-step retrieval capabilities.

    Parameters:
        query: Original input query
        max_steps: Maximum number of retrieval steps
    """
    # Initialize variables
    gist_memory = GistMemory()
    current_query = query
    step = 1
    retrieved_passages = []

    while step <= max_steps:
        # Base retrieval for the current query (e.g., BM25)
        base_passages = base_retriever(current_query)

        # Read passages and extract proximal triples via LLM
        if step == 1:
            proximal_triples = reader(base_passages, query)
        else:
            proximal_triples = reader(base_passages, query, gist_memory.get_all_triples())

        # Link proximal triples to their closest real triples in the index
        triples = tripleLink(proximal_triples)

        # Graph expansion starting from the linked triples
        expanded_passages = graph_expansion(triples, query)

        # Fuse the base and expanded rankings, and save the result for this step
        combined_passages = rrf([base_passages, expanded_passages])
        retrieved_passages.append(combined_passages)

        # Distil the expanded passages into gist triples via LLM
        proximal_triples = gist_memory_constructor(expanded_passages)

        # Add to gist memory
        gist_memory.add_triples(proximal_triples)

        # Check whether we have enough evidence to answer the query
        is_answerable, reasoning = reason(gist_memory.get_all_triples(), query)
        if is_answerable:
            break
        else:
            # Rewrite the query for the next step
            current_query = rewrite(query, gist_memory.get_all_triples(), reasoning)
            step += 1

    # Link the final gist memory triples back to passages
    gist_passages = []
    for triple in gist_memory.get_all_triples():
        linked_passages = passageLink(triple)
        gist_passages.append(linked_passages)

    # Final passage ranking fusing all retrieved passage lists
    final_passages = rrf(gist_passages + retrieved_passages)
    return final_passages
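
# Example invocation (hypothetical question, for illustration only):
#   final_passages = GeAR("In which city was the director of Inception born?", max_steps=4)
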
def SyncGE(query):
    """
    SyncGE pipeline implementation.

    Parameters:
        query: Input query
    """
    # Base retrieval for the query (e.g., BM25)
    base_passages = base_retriever(query)

    # Read passages and extract proximal triples via LLM
    proximal_triples = reader(base_passages, query)

    # Link proximal triples to their closest real triples in the index
    triples = tripleLink(proximal_triples)

    # Graph expansion starting from the linked triples
    expanded_passages = graph_expansion(triples, query)

    # Fuse the base and expanded rankings, and return the result
    combined_passages = rrf([base_passages, expanded_passages])
    return combined_passages
def diverse_triple_search(q, t_list, b, l, γ):
    """
    Performs diverse triple beam search, used in our proposed NaiveGE or SyncGE.

    Parameters:
        q: query
        t_list: initial triples
        b: beam size
        l: maximum path length
        γ: hyperparameter controlling diversity
    """
    # Initialize the beam for the first step by scoring individual triples
    B_prev = []
    for t in t_list:
        s = score(q, [t])
        B_prev.append((s, [t]))
    B_prev = top(B_prev, b)  # Keep the top-b scoring triples

    # Iterative beam search
    for i in range(1, l):
        B = []
        for (s, T) in B_prev:
            V = []  # Candidates extending the current path
            # Explore triples neighbouring the last triple on the path
            for t in get_neighbours(T[-1]):
                # Skip triples already present in the previous beam
                if exists(t, B_prev):
                    continue
                # Score the new path with the concatenated triple
                new_path = T + [t]
                s_new = s + score(q, new_path)
                V.append((s_new, new_path))
            V.sort(key=lambda x: x[0], reverse=True)  # Sort candidates by score

            # Apply a diversity penalty that decays with the candidate's rank
            for n in range(len(V)):
                s_new, path = V[n]
                penalty = exp(-min(n, γ) / γ)
                B.append((s_new * penalty, path))
        B_prev = top(B, b)  # Keep the top-b paths

    return B_prev
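As a rough sketch of how the beam search might be wired into a graph expansion step, consider the call below. The hyperparameter values (beam size, maximum path length, diversity strength) are hypothetical placeholders rather than the settings used in the paper.

# Hypothetical invocation inside a graph expansion step (illustrative values only)
linked_triples = tripleLink(reader(base_passages, query))
paths = diverse_triple_search(q=query, t_list=linked_triples, b=5, l=3, γ=2)
# Each beam entry is a (score, path) pair, where path is a list of graph triples;
# the passages attached to those triples become the expanded passages.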
GeAR achieves state-of-the-art retrieval performance across multiple multi-hop QA benchmarks
@article{shen2024gear,
  title={GeAR: Graph-enhanced Agent for Retrieval-augmented Generation},
  author={Shen, Zhili and Diao, Chenxin and Vougiouklis, Pavlos and Merita, Pascual and Piramanayagam, Shriram and Graux, Damien and Tu, Dandan and Jiang, Zeren and Lai, Ruofei and Ren, Yang and Pan, Jeff Z.},
  journal={arXiv preprint arXiv:2412.18431},
  year={2024}
}