diff --git a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py
index 6e94cb69db..93dcb280ed 100644
--- a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py
+++ b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py
@@ -393,6 +393,39 @@ class QdrantVector(BaseVector):
 
         return documents
 
+    def update_metadata(self, document_id: str, metadata: dict) -> None:
+        """Set the metadata payload of the points matching *document_id*.
+
+        Points are selected by this instance's ``group_id`` and by
+        ``metadata.doc_id == document_id``; the payload entry stored under
+        ``Field.METADATA_KEY`` is then set to *metadata*.
+        """
+        from qdrant_client.http import models
+
+        # NOTE(review): the filter matches ``metadata.doc_id``; confirm callers pass
+        # that identifier here rather than a dataset-level document id.
+        scroll_filter = models.Filter(
+            must=[
+                models.FieldCondition(
+                    key="group_id",
+                    match=models.MatchValue(value=self._group_id),
+                ),
+                models.FieldCondition(
+                    key="metadata.doc_id",
+                    match=models.MatchValue(value=document_id),
+                ),
+            ]
+        )
+        # QdrantClient.set_payload selects its targets through ``points`` (a list
+        # of ids or a Filter); it does not accept a ``filter`` keyword.
+        self._client.set_payload(
+            collection_name=self._collection_name,
+            points=scroll_filter,
+            payload={
+                Field.METADATA_KEY.value: metadata,
+            },
+        )
+
     def _reload_if_needed(self):
         if isinstance(self._client, QdrantLocal):
             self._client = cast(QdrantLocal, self._client)
diff --git a/api/core/rag/datasource/vdb/vector_base.py b/api/core/rag/datasource/vdb/vector_base.py
index edfce2edd8..2b10504630 100644
--- a/api/core/rag/datasource/vdb/vector_base.py
+++ b/api/core/rag/datasource/vdb/vector_base.py
@@ -48,6 +48,14 @@ class BaseVector(ABC):
     @abstractmethod
     def delete(self) -> None:
         raise NotImplementedError
+
+    def update_metadata(self, document_id: str, metadata: dict) -> None:
+        """Update the metadata stored for *document_id*; backends override this.
+
+        Deliberately not abstract: marking it ``@abstractmethod`` would make every
+        subclass that does not override it impossible to instantiate.
+        """
+        raise NotImplementedError
 
     def _filter_duplicate_texts(self, texts: list[Document]) -> list[Document]:
         for text in texts.copy():
diff --git a/api/fields/dataset_fields.py b/api/fields/dataset_fields.py
index bedab5750f..b96074dc0d 100644
--- a/api/fields/dataset_fields.py
+++ b/api/fields/dataset_fields.py
@@ -87,3 +87,10 @@ dataset_query_detail_fields = {
     "created_by": fields.String,
     "created_at": TimestampField,
 }
+
+# Field mapping for one dataset metadata entry: id, type and name as strings.
+dataset_metadata_fields = {
+    "id": fields.String,
+    "type": fields.String,
+    "name": fields.String,
+}
diff --git a/api/models/dataset.py b/api/models/dataset.py index 
1cf3dc42fe..3d5ee70e30 100644
--- a/api/models/dataset.py
+++ b/api/models/dataset.py
@@ -926,3 +926,45 @@ class DatasetAutoDisableLog(db.Model):  # type: ignore[name-defined]
     document_id = db.Column(StringUUID, nullable=False)
     notified = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
     created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
+
+
+class DatasetMetadata(db.Model):  # type: ignore[name-defined]
+    """Metadata entry (``type`` and ``name``) stored per dataset and tenant."""
+
+    __tablename__ = "dataset_metadatas"
+    __table_args__ = (
+        db.PrimaryKeyConstraint("id", name="dataset_metadata_pkey"),
+        db.Index("dataset_metadata_tenant_idx", "tenant_id"),
+        db.Index("dataset_metadata_dataset_idx", "dataset_id"),
+    )
+
+    id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
+    tenant_id = db.Column(StringUUID, nullable=False)
+    dataset_id = db.Column(StringUUID, nullable=False)
+    type = db.Column(db.String(255), nullable=False)
+    name = db.Column(db.String(255), nullable=False)
+    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
+    updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
+    created_by = db.Column(StringUUID, nullable=False)
+    updated_by = db.Column(StringUUID, nullable=True)
+
+
+class DatasetMetadataBinding(db.Model):  # type: ignore[name-defined]
+    """Association row holding a ``metadata_id`` and a ``document_id`` within a dataset."""
+
+    __tablename__ = "dataset_metadata_bindings"
+    __table_args__ = (
+        db.PrimaryKeyConstraint("id", name="dataset_metadata_binding_pkey"),
+        db.Index("dataset_metadata_binding_tenant_idx", "tenant_id"),
+        db.Index("dataset_metadata_binding_dataset_idx", "dataset_id"),
+        db.Index("dataset_metadata_binding_metadata_idx", "metadata_id"),
+        db.Index("dataset_metadata_binding_document_idx", "document_id"),
+    )
+
+    id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
+    tenant_id = db.Column(StringUUID, nullable=False)
+    dataset_id = db.Column(StringUUID, nullable=False)
+    metadata_id = db.Column(StringUUID, nullable=False)
+    document_id = db.Column(StringUUID, nullable=False)
+    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
+    created_by = db.Column(StringUUID, nullable=False)
diff --git a/api/services/entities/knowledge_entities/knowledge_entities.py b/api/services/entities/knowledge_entities/knowledge_entities.py
index f14c5b513a..f8f6b6c132 100644
--- a/api/services/entities/knowledge_entities/knowledge_entities.py
+++ b/api/services/entities/knowledge_entities/knowledge_entities.py
@@ -124,3 +124,10 @@ class SegmentUpdateArgs(BaseModel):
 class ChildChunkUpdateArgs(BaseModel):
     id: Optional[str] = None
     content: str
+
+
+class MetadataArgs(BaseModel):
+    """Arguments describing a metadata field: its value type and its name."""
+
+    type: Literal["string", "number", "time"]
+    name: str