fix term weight issue (#3306)
### What problem does this PR solve? ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue)
This commit is contained in:
parent
74d1eeb4d3
commit
004487cca0
@ -34,12 +34,13 @@ from tqdm import tqdm
|
|||||||
|
|
||||||
class Benchmark:
|
class Benchmark:
|
||||||
def __init__(self, kb_id):
|
def __init__(self, kb_id):
|
||||||
e, kb = KnowledgebaseService.get_by_id(kb_id)
|
e, self.kb = KnowledgebaseService.get_by_id(kb_id)
|
||||||
self.similarity_threshold = kb.similarity_threshold
|
self.similarity_threshold = self.kb.similarity_threshold
|
||||||
self.vector_similarity_weight = kb.vector_similarity_weight
|
self.vector_similarity_weight = self.kb.vector_similarity_weight
|
||||||
self.embd_mdl = LLMBundle(kb.tenant_id, LLMType.EMBEDDING, llm_name=kb.embd_id, lang=kb.language)
|
self.embd_mdl = LLMBundle(self.kb.tenant_id, LLMType.EMBEDDING, llm_name=self.kb.embd_id, lang=self.kb.language)
|
||||||
|
|
||||||
def _get_benchmarks(self, query, dataset_idxnm, count=16):
|
def _get_benchmarks(self, query, dataset_idxnm, count=16):
|
||||||
|
|
||||||
req = {"question": query, "size": count, "vector": True, "similarity": self.similarity_threshold}
|
req = {"question": query, "size": count, "vector": True, "similarity": self.similarity_threshold}
|
||||||
sres = retrievaler.search(req, search.index_name(dataset_idxnm), self.embd_mdl)
|
sres = retrievaler.search(req, search.index_name(dataset_idxnm), self.embd_mdl)
|
||||||
return sres
|
return sres
|
||||||
@ -48,11 +49,15 @@ class Benchmark:
|
|||||||
run = defaultdict(dict)
|
run = defaultdict(dict)
|
||||||
query_list = list(qrels.keys())
|
query_list = list(qrels.keys())
|
||||||
for query in query_list:
|
for query in query_list:
|
||||||
sres = self._get_benchmarks(query, dataset_idxnm)
|
|
||||||
sim, _, _ = retrievaler.rerank(sres, query, 1 - self.vector_similarity_weight,
|
ranks = retrievaler.retrieval(query, self.embd_mdl, dataset_idxnm.replace("ragflow_", ""),
|
||||||
self.vector_similarity_weight)
|
[self.kb.id], 0, 30,
|
||||||
for index, id in enumerate(sres.ids):
|
0.0, self.vector_similarity_weight)
|
||||||
run[query][id] = sim[index]
|
for c in ranks["chunks"]:
|
||||||
|
if "vector" in c:
|
||||||
|
del c["vector"]
|
||||||
|
run[query][c["chunk_id"]] = c["similarity"]
|
||||||
|
|
||||||
return run
|
return run
|
||||||
|
|
||||||
def embedding(self, docs, batch_size=16):
|
def embedding(self, docs, batch_size=16):
|
||||||
@ -99,7 +104,8 @@ class Benchmark:
|
|||||||
query = data.iloc[i]['query']
|
query = data.iloc[i]['query']
|
||||||
for rel, text in zip(data.iloc[i]['passages']['is_selected'], data.iloc[i]['passages']['passage_text']):
|
for rel, text in zip(data.iloc[i]['passages']['is_selected'], data.iloc[i]['passages']['passage_text']):
|
||||||
d = {
|
d = {
|
||||||
"id": get_uuid()
|
"id": get_uuid(),
|
||||||
|
"kb_id": self.kb.id
|
||||||
}
|
}
|
||||||
tokenize(d, text, "english")
|
tokenize(d, text, "english")
|
||||||
docs.append(d)
|
docs.append(d)
|
||||||
@ -208,6 +214,8 @@ class Benchmark:
|
|||||||
scores = sorted(scores, key=lambda kk: kk[1])
|
scores = sorted(scores, key=lambda kk: kk[1])
|
||||||
for score in scores[:10]:
|
for score in scores[:10]:
|
||||||
f.write('- text: ' + str(texts[score[0]]) + '\t qrel: ' + str(score[1]) + '\n')
|
f.write('- text: ' + str(texts[score[0]]) + '\t qrel: ' + str(score[1]) + '\n')
|
||||||
|
json.dump(qrels, open(os.path.join(file_path, dataset + '.qrels.json'), "w+"), indent=2)
|
||||||
|
json.dump(run, open(os.path.join(file_path, dataset + '.run.json'), "w+"), indent=2)
|
||||||
print(os.path.join(file_path, dataset + '_result.md'), 'Saved!')
|
print(os.path.join(file_path, dataset + '_result.md'), 'Saved!')
|
||||||
|
|
||||||
def __call__(self, dataset, file_path, miracl_corpus=''):
|
def __call__(self, dataset, file_path, miracl_corpus=''):
|
||||||
|
|||||||
@ -211,8 +211,8 @@ class Dealer:
|
|||||||
continue
|
continue
|
||||||
if not isinstance(v, type("")):
|
if not isinstance(v, type("")):
|
||||||
m[n] = str(m[n])
|
m[n] = str(m[n])
|
||||||
if n.find("tks") > 0:
|
#if n.find("tks") > 0:
|
||||||
m[n] = rmSpace(m[n])
|
# m[n] = rmSpace(m[n])
|
||||||
|
|
||||||
if m:
|
if m:
|
||||||
res[d["id"]] = m
|
res[d["id"]] = m
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user