diff --git a/rag/utils/infinity_conn.py b/rag/utils/infinity_conn.py index 78c72cd6..b6d3202a 100644 --- a/rag/utils/infinity_conn.py +++ b/rag/utils/infinity_conn.py @@ -273,9 +273,22 @@ class InfinityConnection(DocStoreConnection): for essential_field in ["id"]: if essential_field not in selectFields: selectFields.append(essential_field) + score_func = "" + score_column = "" + for matchExpr in matchExprs: + if isinstance(matchExpr, MatchTextExpr): + score_func = "score()" + score_column = "SCORE" + break + if not score_func: + for matchExpr in matchExprs: + if isinstance(matchExpr, MatchDenseExpr): + score_func = "similarity()" + score_column = "SIMILARITY" + break if matchExprs: - for essential_field in ["score()", PAGERANK_FLD]: - selectFields.append(essential_field) + selectFields.append(score_func) + selectFields.append(PAGERANK_FLD) # Prepare expressions common to all tables filter_cond = None @@ -364,7 +377,9 @@ class InfinityConnection(DocStoreConnection): self.connPool.release_conn(inf_conn) res = concat_dataframes(df_list, selectFields) if matchExprs: - res = res.sort(pl.col("SCORE") + pl.col(PAGERANK_FLD), descending=True, maintain_order=True) + res = res.sort(pl.col(score_column) + pl.col(PAGERANK_FLD), descending=True, maintain_order=True) + if score_column and score_column != "SCORE": + res = res.rename({score_column: "SCORE"}) res = res.limit(limit) logger.debug(f"INFINITY search final result: {str(res)}") return res, total_hits_count