dataset metadata fix
This commit is contained in:
parent
c53786d229
commit
9e258c495d
@ -617,7 +617,7 @@ class DocumentDetailApi(DocumentResource):
|
|||||||
raise InvalidMetadataError(f"Invalid metadata value: {metadata}")
|
raise InvalidMetadataError(f"Invalid metadata value: {metadata}")
|
||||||
|
|
||||||
if metadata == "only":
|
if metadata == "only":
|
||||||
response = {"id": document.id, "doc_type": document.doc_type, "doc_metadata": document.doc_metadata}
|
response = {"id": document.id, "doc_type": document.doc_type, "doc_metadata": document.doc_metadata_details}
|
||||||
elif metadata == "without":
|
elif metadata == "without":
|
||||||
dataset_process_rules = DatasetService.get_process_rules(dataset_id)
|
dataset_process_rules = DatasetService.get_process_rules(dataset_id)
|
||||||
document_process_rules = document.dataset_process_rule.to_dict()
|
document_process_rules = document.dataset_process_rule.to_dict()
|
||||||
@ -678,7 +678,7 @@ class DocumentDetailApi(DocumentResource):
|
|||||||
"disabled_by": document.disabled_by,
|
"disabled_by": document.disabled_by,
|
||||||
"archived": document.archived,
|
"archived": document.archived,
|
||||||
"doc_type": document.doc_type,
|
"doc_type": document.doc_type,
|
||||||
"doc_metadata": document.doc_metadata,
|
"doc_metadata": document.doc_metadata_details,
|
||||||
"segment_count": document.segment_count,
|
"segment_count": document.segment_count,
|
||||||
"average_segment_length": document.average_segment_length,
|
"average_segment_length": document.average_segment_length,
|
||||||
"hit_count": document.hit_count,
|
"hit_count": document.hit_count,
|
||||||
|
|||||||
@ -197,8 +197,8 @@ class AnalyticdbVectorBySql:
|
|||||||
document_ids_filter = kwargs.get("document_ids_filter")
|
document_ids_filter = kwargs.get("document_ids_filter")
|
||||||
where_clause = "WHERE 1=1"
|
where_clause = "WHERE 1=1"
|
||||||
if document_ids_filter:
|
if document_ids_filter:
|
||||||
doc_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
|
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
|
||||||
where_clause += f"AND metadata_->>'doc_id' IN ({doc_ids})"
|
where_clause += f"AND metadata_->>'document_id' IN ({document_ids})"
|
||||||
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
score_threshold = float(kwargs.get("score_threshold") or 0.0)
|
||||||
with self._get_cursor() as cur:
|
with self._get_cursor() as cur:
|
||||||
query_vector_str = json.dumps(query_vector)
|
query_vector_str = json.dumps(query_vector)
|
||||||
@ -228,8 +228,8 @@ class AnalyticdbVectorBySql:
|
|||||||
document_ids_filter = kwargs.get("document_ids_filter")
|
document_ids_filter = kwargs.get("document_ids_filter")
|
||||||
where_clause = ""
|
where_clause = ""
|
||||||
if document_ids_filter:
|
if document_ids_filter:
|
||||||
doc_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
|
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
|
||||||
where_clause += f"AND metadata_->>'doc_id' IN ({doc_ids})"
|
where_clause += f"AND metadata_->>'document_id' IN ({document_ids})"
|
||||||
with self._get_cursor() as cur:
|
with self._get_cursor() as cur:
|
||||||
cur.execute(
|
cur.execute(
|
||||||
f"""SELECT id, vector, page_content, metadata_,
|
f"""SELECT id, vector, page_content, metadata_,
|
||||||
|
|||||||
@ -125,12 +125,12 @@ class BaiduVector(BaseVector):
|
|||||||
query_vector = [float(val) if isinstance(val, np.float64) else val for val in query_vector]
|
query_vector = [float(val) if isinstance(val, np.float64) else val for val in query_vector]
|
||||||
document_ids_filter = kwargs.get("document_ids_filter")
|
document_ids_filter = kwargs.get("document_ids_filter")
|
||||||
if document_ids_filter:
|
if document_ids_filter:
|
||||||
doc_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
|
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
|
||||||
anns = AnnSearch(
|
anns = AnnSearch(
|
||||||
vector_field=self.field_vector,
|
vector_field=self.field_vector,
|
||||||
vector_floats=query_vector,
|
vector_floats=query_vector,
|
||||||
params=HNSWSearchParams(ef=kwargs.get("ef", 10), limit=kwargs.get("top_k", 4)),
|
params=HNSWSearchParams(ef=kwargs.get("ef", 10), limit=kwargs.get("top_k", 4)),
|
||||||
filter=f"doc_id IN ({doc_ids})",
|
filter=f"document_id IN ({document_ids})",
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
anns = AnnSearch(
|
anns = AnnSearch(
|
||||||
|
|||||||
@ -100,7 +100,7 @@ class ChromaVector(BaseVector):
|
|||||||
results: QueryResult = collection.query(
|
results: QueryResult = collection.query(
|
||||||
query_embeddings=query_vector,
|
query_embeddings=query_vector,
|
||||||
n_results=kwargs.get("top_k", 4),
|
n_results=kwargs.get("top_k", 4),
|
||||||
where={"doc_id": {"$in": document_ids_filter}},
|
where={"document_id": {"$in": document_ids_filter}},
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
results: QueryResult = collection.query(query_embeddings=query_vector, n_results=kwargs.get("top_k", 4))
|
results: QueryResult = collection.query(query_embeddings=query_vector, n_results=kwargs.get("top_k", 4))
|
||||||
|
|||||||
@ -119,7 +119,7 @@ class ElasticSearchVector(BaseVector):
|
|||||||
knn = {"field": Field.VECTOR.value, "query_vector": query_vector, "k": top_k, "num_candidates": num_candidates}
|
knn = {"field": Field.VECTOR.value, "query_vector": query_vector, "k": top_k, "num_candidates": num_candidates}
|
||||||
document_ids_filter = kwargs.get("document_ids_filter")
|
document_ids_filter = kwargs.get("document_ids_filter")
|
||||||
if document_ids_filter:
|
if document_ids_filter:
|
||||||
knn["filter"] = {"terms": {"metadata.doc_id": document_ids_filter}}
|
knn["filter"] = {"terms": {"metadata.document_id": document_ids_filter}}
|
||||||
|
|
||||||
results = self._client.search(index=self._collection_name, knn=knn, size=top_k)
|
results = self._client.search(index=self._collection_name, knn=knn, size=top_k)
|
||||||
|
|
||||||
@ -150,7 +150,7 @@ class ElasticSearchVector(BaseVector):
|
|||||||
query_str = {"match": {Field.CONTENT_KEY.value: query}}
|
query_str = {"match": {Field.CONTENT_KEY.value: query}}
|
||||||
document_ids_filter = kwargs.get("document_ids_filter")
|
document_ids_filter = kwargs.get("document_ids_filter")
|
||||||
if document_ids_filter:
|
if document_ids_filter:
|
||||||
query_str["filter"] = {"terms": {"metadata.doc_id": document_ids_filter}}
|
query_str["filter"] = {"terms": {"metadata.document_id": document_ids_filter}}
|
||||||
results = self._client.search(index=self._collection_name, query=query_str, size=kwargs.get("top_k", 4))
|
results = self._client.search(index=self._collection_name, query=query_str, size=kwargs.get("top_k", 4))
|
||||||
docs = []
|
docs = []
|
||||||
for hit in results["hits"]["hits"]:
|
for hit in results["hits"]["hits"]:
|
||||||
|
|||||||
@ -171,7 +171,7 @@ class LindormVectorStore(BaseVector):
|
|||||||
document_ids_filter = kwargs.get("document_ids_filter")
|
document_ids_filter = kwargs.get("document_ids_filter")
|
||||||
filters = []
|
filters = []
|
||||||
if document_ids_filter:
|
if document_ids_filter:
|
||||||
filters.append({"terms": {"metadata.doc_id": document_ids_filter}})
|
filters.append({"terms": {"metadata.document_id": document_ids_filter}})
|
||||||
query = default_vector_search_query(query_vector=query_vector, k=top_k, filters=filters, **kwargs)
|
query = default_vector_search_query(query_vector=query_vector, k=top_k, filters=filters, **kwargs)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -214,7 +214,7 @@ class LindormVectorStore(BaseVector):
|
|||||||
filters = kwargs.get("filter", [])
|
filters = kwargs.get("filter", [])
|
||||||
document_ids_filter = kwargs.get("document_ids_filter")
|
document_ids_filter = kwargs.get("document_ids_filter")
|
||||||
if document_ids_filter:
|
if document_ids_filter:
|
||||||
filters.append({"terms": {"metadata.doc_id": document_ids_filter}})
|
filters.append({"terms": {"metadata.document_id": document_ids_filter}})
|
||||||
routing = self._routing
|
routing = self._routing
|
||||||
full_text_query = default_text_search_query(
|
full_text_query = default_text_search_query(
|
||||||
query_text=query,
|
query_text=query,
|
||||||
|
|||||||
@ -221,8 +221,8 @@ class MilvusVector(BaseVector):
|
|||||||
document_ids_filter = kwargs.get("document_ids_filter")
|
document_ids_filter = kwargs.get("document_ids_filter")
|
||||||
filter = ""
|
filter = ""
|
||||||
if document_ids_filter:
|
if document_ids_filter:
|
||||||
doc_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
|
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
|
||||||
filter = f'metadata["doc_id"] in ({doc_ids})'
|
filter = f'metadata["document_id"] in ({document_ids})'
|
||||||
results = self._client.search(
|
results = self._client.search(
|
||||||
collection_name=self._collection_name,
|
collection_name=self._collection_name,
|
||||||
data=[query_vector],
|
data=[query_vector],
|
||||||
@ -248,8 +248,8 @@ class MilvusVector(BaseVector):
|
|||||||
document_ids_filter = kwargs.get("document_ids_filter")
|
document_ids_filter = kwargs.get("document_ids_filter")
|
||||||
filter = ""
|
filter = ""
|
||||||
if document_ids_filter:
|
if document_ids_filter:
|
||||||
doc_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
|
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
|
||||||
filter = f'metadata["doc_id"] in ({doc_ids})'
|
filter = f'metadata["document_id"] in ({document_ids})'
|
||||||
|
|
||||||
results = self._client.search(
|
results = self._client.search(
|
||||||
collection_name=self._collection_name,
|
collection_name=self._collection_name,
|
||||||
|
|||||||
@ -133,8 +133,8 @@ class MyScaleVector(BaseVector):
|
|||||||
)
|
)
|
||||||
document_ids_filter = kwargs.get("document_ids_filter")
|
document_ids_filter = kwargs.get("document_ids_filter")
|
||||||
if document_ids_filter:
|
if document_ids_filter:
|
||||||
doc_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
|
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
|
||||||
where_str = f"{where_str} AND metadata['doc_id'] in ({doc_ids})"
|
where_str = f"{where_str} AND metadata['document_id'] in ({document_ids})"
|
||||||
sql = f"""
|
sql = f"""
|
||||||
SELECT text, vector, metadata, {dist} as dist FROM {self._config.database}.{self._collection_name}
|
SELECT text, vector, metadata, {dist} as dist FROM {self._config.database}.{self._collection_name}
|
||||||
{where_str} ORDER BY dist {order.value} LIMIT {top_k}
|
{where_str} ORDER BY dist {order.value} LIMIT {top_k}
|
||||||
|
|||||||
@ -157,8 +157,8 @@ class OceanBaseVector(BaseVector):
|
|||||||
document_ids_filter = kwargs.get("document_ids_filter")
|
document_ids_filter = kwargs.get("document_ids_filter")
|
||||||
where_clause = None
|
where_clause = None
|
||||||
if document_ids_filter:
|
if document_ids_filter:
|
||||||
doc_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
|
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
|
||||||
where_clause = f"metadata->>'$.doc_id' in ({doc_ids})"
|
where_clause = f"metadata->>'$.document_id' in ({document_ids})"
|
||||||
ef_search = kwargs.get("ef_search", self._hnsw_ef_search)
|
ef_search = kwargs.get("ef_search", self._hnsw_ef_search)
|
||||||
if ef_search != self._hnsw_ef_search:
|
if ef_search != self._hnsw_ef_search:
|
||||||
self._client.set_ob_hnsw_ef_search(ef_search)
|
self._client.set_ob_hnsw_ef_search(ef_search)
|
||||||
|
|||||||
@ -156,7 +156,7 @@ class OpenSearchVector(BaseVector):
|
|||||||
}
|
}
|
||||||
document_ids_filter = kwargs.get("document_ids_filter")
|
document_ids_filter = kwargs.get("document_ids_filter")
|
||||||
if document_ids_filter:
|
if document_ids_filter:
|
||||||
query["query"] = {"terms": {"metadata.doc_id": document_ids_filter}}
|
query["query"] = {"terms": {"metadata.document_id": document_ids_filter}}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = self._client.search(index=self._collection_name.lower(), body=query)
|
response = self._client.search(index=self._collection_name.lower(), body=query)
|
||||||
@ -184,7 +184,7 @@ class OpenSearchVector(BaseVector):
|
|||||||
full_text_query = {"query": {"match": {Field.CONTENT_KEY.value: query}}}
|
full_text_query = {"query": {"match": {Field.CONTENT_KEY.value: query}}}
|
||||||
document_ids_filter = kwargs.get("document_ids_filter")
|
document_ids_filter = kwargs.get("document_ids_filter")
|
||||||
if document_ids_filter:
|
if document_ids_filter:
|
||||||
full_text_query["query"]["terms"] = {"metadata.doc_id": document_ids_filter}
|
full_text_query["query"]["terms"] = {"metadata.document_id": document_ids_filter}
|
||||||
|
|
||||||
response = self._client.search(index=self._collection_name.lower(), body=full_text_query)
|
response = self._client.search(index=self._collection_name.lower(), body=full_text_query)
|
||||||
|
|
||||||
|
|||||||
@ -188,8 +188,8 @@ class OracleVector(BaseVector):
|
|||||||
document_ids_filter = kwargs.get("document_ids_filter")
|
document_ids_filter = kwargs.get("document_ids_filter")
|
||||||
where_clause = ""
|
where_clause = ""
|
||||||
if document_ids_filter:
|
if document_ids_filter:
|
||||||
doc_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
|
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
|
||||||
where_clause = f"WHERE metadata->>'doc_id' in ({doc_ids})"
|
where_clause = f"WHERE metadata->>'document_id' in ({document_ids})"
|
||||||
with self._get_cursor() as cur:
|
with self._get_cursor() as cur:
|
||||||
cur.execute(
|
cur.execute(
|
||||||
f"SELECT meta, text, vector_distance(embedding,:1) AS distance FROM {self.table_name}"
|
f"SELECT meta, text, vector_distance(embedding,:1) AS distance FROM {self.table_name}"
|
||||||
@ -249,8 +249,8 @@ class OracleVector(BaseVector):
|
|||||||
document_ids_filter = kwargs.get("document_ids_filter")
|
document_ids_filter = kwargs.get("document_ids_filter")
|
||||||
where_clause = ""
|
where_clause = ""
|
||||||
if document_ids_filter:
|
if document_ids_filter:
|
||||||
doc_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
|
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
|
||||||
where_clause = f" AND metadata->>'doc_id' in ({doc_ids}) "
|
where_clause = f" AND metadata->>'document_id' in ({document_ids}) "
|
||||||
cur.execute(
|
cur.execute(
|
||||||
f"select meta, text, embedding FROM {self.table_name}"
|
f"select meta, text, embedding FROM {self.table_name}"
|
||||||
f"WHERE CONTAINS(text, :1, 1) > 0 {where_clause} "
|
f"WHERE CONTAINS(text, :1, 1) > 0 {where_clause} "
|
||||||
|
|||||||
@ -191,7 +191,7 @@ class PGVectoRS(BaseVector):
|
|||||||
)
|
)
|
||||||
document_ids_filter = kwargs.get("document_ids_filter")
|
document_ids_filter = kwargs.get("document_ids_filter")
|
||||||
if document_ids_filter:
|
if document_ids_filter:
|
||||||
stmt = stmt.where(self._table.meta["doc_id"].in_(document_ids_filter))
|
stmt = stmt.where(self._table.meta["document_id"].in_(document_ids_filter))
|
||||||
res = session.execute(stmt)
|
res = session.execute(stmt)
|
||||||
results = [(row[0], row[1]) for row in res]
|
results = [(row[0], row[1]) for row in res]
|
||||||
|
|
||||||
|
|||||||
@ -158,8 +158,8 @@ class PGVector(BaseVector):
|
|||||||
document_ids_filter = kwargs.get("document_ids_filter")
|
document_ids_filter = kwargs.get("document_ids_filter")
|
||||||
where_clause = ""
|
where_clause = ""
|
||||||
if document_ids_filter:
|
if document_ids_filter:
|
||||||
doc_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
|
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
|
||||||
where_clause = f" WHERE metadata->>'doc_id' in ({doc_ids}) "
|
where_clause = f" WHERE metadata->>'document_id' in ({document_ids}) "
|
||||||
|
|
||||||
with self._get_cursor() as cur:
|
with self._get_cursor() as cur:
|
||||||
cur.execute(
|
cur.execute(
|
||||||
@ -185,8 +185,8 @@ class PGVector(BaseVector):
|
|||||||
document_ids_filter = kwargs.get("document_ids_filter")
|
document_ids_filter = kwargs.get("document_ids_filter")
|
||||||
where_clause = ""
|
where_clause = ""
|
||||||
if document_ids_filter:
|
if document_ids_filter:
|
||||||
doc_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
|
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
|
||||||
where_clause = f" AND metadata->>'doc_id' in ({doc_ids}) "
|
where_clause = f" AND metadata->>'document_id' in ({document_ids}) "
|
||||||
cur.execute(
|
cur.execute(
|
||||||
f"""SELECT meta, text, ts_rank(to_tsvector(coalesce(text, '')), plainto_tsquery(%s)) AS score
|
f"""SELECT meta, text, ts_rank(to_tsvector(coalesce(text, '')), plainto_tsquery(%s)) AS score
|
||||||
FROM {self.table_name}
|
FROM {self.table_name}
|
||||||
|
|||||||
@ -334,7 +334,7 @@ class QdrantVector(BaseVector):
|
|||||||
if document_ids_filter:
|
if document_ids_filter:
|
||||||
filter.must.append(
|
filter.must.append(
|
||||||
models.FieldCondition(
|
models.FieldCondition(
|
||||||
key="metadata.doc_id",
|
key="metadata.document_id",
|
||||||
match=models.MatchAny(any=document_ids_filter),
|
match=models.MatchAny(any=document_ids_filter),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@ -388,7 +388,7 @@ class QdrantVector(BaseVector):
|
|||||||
if document_ids_filter:
|
if document_ids_filter:
|
||||||
scroll_filter.must.append(
|
scroll_filter.must.append(
|
||||||
models.FieldCondition(
|
models.FieldCondition(
|
||||||
key="metadata.doc_id",
|
key="metadata.document_id",
|
||||||
match=models.MatchAny(any=document_ids_filter),
|
match=models.MatchAny(any=document_ids_filter),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|||||||
@ -226,7 +226,7 @@ class RelytVector(BaseVector):
|
|||||||
document_ids_filter = kwargs.get("document_ids_filter")
|
document_ids_filter = kwargs.get("document_ids_filter")
|
||||||
filter = kwargs.get("filter", {})
|
filter = kwargs.get("filter", {})
|
||||||
if document_ids_filter:
|
if document_ids_filter:
|
||||||
filter["doc_id"] = document_ids_filter
|
filter["document_id"] = document_ids_filter
|
||||||
results = self.similarity_search_with_score_by_vector(
|
results = self.similarity_search_with_score_by_vector(
|
||||||
k=int(kwargs.get("top_k", 4)), embedding=query_vector, filter=filter
|
k=int(kwargs.get("top_k", 4)), embedding=query_vector, filter=filter
|
||||||
)
|
)
|
||||||
|
|||||||
@ -151,7 +151,7 @@ class TencentVector(BaseVector):
|
|||||||
document_ids_filter = kwargs.get("document_ids_filter")
|
document_ids_filter = kwargs.get("document_ids_filter")
|
||||||
filter = None
|
filter = None
|
||||||
if document_ids_filter:
|
if document_ids_filter:
|
||||||
filter = Filter(Filter.In("metadata.doc_id", document_ids_filter))
|
filter = Filter(Filter.In("metadata.document_id", document_ids_filter))
|
||||||
res = self._db.collection(self._collection_name).search(
|
res = self._db.collection(self._collection_name).search(
|
||||||
vectors=[query_vector],
|
vectors=[query_vector],
|
||||||
filter=filter,
|
filter=filter,
|
||||||
|
|||||||
@ -330,7 +330,7 @@ class TidbOnQdrantVector(BaseVector):
|
|||||||
if document_ids_filter:
|
if document_ids_filter:
|
||||||
filter.must.append(
|
filter.must.append(
|
||||||
models.FieldCondition(
|
models.FieldCondition(
|
||||||
key="metadata.doc_id",
|
key="metadata.document_id",
|
||||||
match=models.MatchAny(any=document_ids_filter),
|
match=models.MatchAny(any=document_ids_filter),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@ -380,7 +380,7 @@ class TidbOnQdrantVector(BaseVector):
|
|||||||
if document_ids_filter:
|
if document_ids_filter:
|
||||||
scroll_filter.must.append(
|
scroll_filter.must.append(
|
||||||
models.FieldCondition(
|
models.FieldCondition(
|
||||||
key="metadata.doc_id",
|
key="metadata.document_id",
|
||||||
match=models.MatchAny(any=document_ids_filter),
|
match=models.MatchAny(any=document_ids_filter),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|||||||
@ -199,8 +199,8 @@ class TiDBVector(BaseVector):
|
|||||||
document_ids_filter = kwargs.get("document_ids_filter")
|
document_ids_filter = kwargs.get("document_ids_filter")
|
||||||
where_clause = ""
|
where_clause = ""
|
||||||
if document_ids_filter:
|
if document_ids_filter:
|
||||||
doc_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
|
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
|
||||||
where_clause = f" WHERE meta->>'$.doc_id' in ({doc_ids}) "
|
where_clause = f" WHERE meta->>'$.document_id' in ({document_ids}) "
|
||||||
|
|
||||||
with Session(self._engine) as session:
|
with Session(self._engine) as session:
|
||||||
select_statement = sql_text(f"""
|
select_statement = sql_text(f"""
|
||||||
|
|||||||
@ -90,7 +90,8 @@ class UpstashVector(BaseVector):
|
|||||||
top_k = kwargs.get("top_k", 4)
|
top_k = kwargs.get("top_k", 4)
|
||||||
document_ids_filter = kwargs.get("document_ids_filter")
|
document_ids_filter = kwargs.get("document_ids_filter")
|
||||||
if document_ids_filter:
|
if document_ids_filter:
|
||||||
filter = f"doc_id in ({', '.join(f"'{id}'" for id in document_ids_filter)})"
|
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
|
||||||
|
filter = f"document_id in ({document_ids})"
|
||||||
else:
|
else:
|
||||||
filter = ""
|
filter = ""
|
||||||
result = self.index.query(
|
result = self.index.query(
|
||||||
|
|||||||
@ -180,7 +180,7 @@ class VikingDBVector(BaseVector):
|
|||||||
docs = self._get_search_res(results, score_threshold)
|
docs = self._get_search_res(results, score_threshold)
|
||||||
document_ids_filter = kwargs.get("document_ids_filter")
|
document_ids_filter = kwargs.get("document_ids_filter")
|
||||||
if document_ids_filter:
|
if document_ids_filter:
|
||||||
docs = [doc for doc in docs if doc.metadata.get("doc_id") in document_ids_filter]
|
docs = [doc for doc in docs if doc.metadata.get("document_id") in document_ids_filter]
|
||||||
return docs
|
return docs
|
||||||
|
|
||||||
def _get_search_res(self, results, score_threshold) -> list[Document]:
|
def _get_search_res(self, results, score_threshold) -> list[Document]:
|
||||||
|
|||||||
@ -189,7 +189,7 @@ class WeaviateVector(BaseVector):
|
|||||||
vector = {"vector": query_vector}
|
vector = {"vector": query_vector}
|
||||||
document_ids_filter = kwargs.get("document_ids_filter")
|
document_ids_filter = kwargs.get("document_ids_filter")
|
||||||
if document_ids_filter:
|
if document_ids_filter:
|
||||||
where_filter = {"operator": "ContainsAny", "path": ["doc_id"], "valueTextArray": document_ids_filter}
|
where_filter = {"operator": "ContainsAny", "path": ["document_id"], "valueTextArray": document_ids_filter}
|
||||||
query_obj = query_obj.with_where(where_filter)
|
query_obj = query_obj.with_where(where_filter)
|
||||||
result = (
|
result = (
|
||||||
query_obj.with_near_vector(vector)
|
query_obj.with_near_vector(vector)
|
||||||
@ -237,7 +237,7 @@ class WeaviateVector(BaseVector):
|
|||||||
query_obj = self._client.query.get(collection_name, properties)
|
query_obj = self._client.query.get(collection_name, properties)
|
||||||
document_ids_filter = kwargs.get("document_ids_filter")
|
document_ids_filter = kwargs.get("document_ids_filter")
|
||||||
if document_ids_filter:
|
if document_ids_filter:
|
||||||
where_filter = {"operator": "ContainsAny", "path": ["doc_id"], "valueTextArray": document_ids_filter}
|
where_filter = {"operator": "ContainsAny", "path": ["document_id"], "valueTextArray": document_ids_filter}
|
||||||
query_obj = query_obj.with_where(where_filter)
|
query_obj = query_obj.with_where(where_filter)
|
||||||
query_obj = query_obj.with_additional(["vector"])
|
query_obj = query_obj.with_additional(["vector"])
|
||||||
properties = ["text"]
|
properties = ["text"]
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user