Spaces:
Sleeping
Sleeping
| use async_trait::async_trait; | |
| use uuid::Uuid; | |
| use crate::chroma_proto; | |
| use crate::config::{Configurable, WorkerConfig}; | |
| use crate::types::{CollectionConversionError, SegmentConversionError}; | |
| use crate::{ | |
| chroma_proto::sys_db_client, | |
| errors::{ChromaError, ErrorCodes}, | |
| types::{Collection, Segment, SegmentScope}, | |
| }; | |
| use thiserror::Error; | |
| use super::config::SysDbConfig; | |
/// Database name sent to the sysdb when the caller does not supply one.
// NOTE: the identifier is a misspelling of "DATABASE"; it is kept as-is because
// other code in this file refers to it by this name.
const DEFAULT_DATBASE: &str = "default_database";
/// Tenant name sent to the sysdb when the caller does not supply one.
const DEFAULT_TENANT: &str = "default_tenant";
| pub(crate) trait SysDb: Send + Sync + SysDbClone { | |
| async fn get_collections( | |
| &mut self, | |
| collection_id: Option<Uuid>, | |
| topic: Option<String>, | |
| name: Option<String>, | |
| tenant: Option<String>, | |
| database: Option<String>, | |
| ) -> Result<Vec<Collection>, GetCollectionsError>; | |
| async fn get_segments( | |
| &mut self, | |
| id: Option<Uuid>, | |
| r#type: Option<String>, | |
| scope: Option<SegmentScope>, | |
| topic: Option<String>, | |
| collection: Option<Uuid>, | |
| ) -> Result<Vec<Segment>, GetSegmentsError>; | |
| } | |
| // We'd like to be able to clone the trait object, so we need to use the | |
| // "clone box" pattern. See https://stackoverflow.com/questions/30353462/how-to-clone-a-struct-storing-a-boxed-trait-object#comment48814207_30353928 | |
| // https://chat.openai.com/share/b3eae92f-0b80-446f-b79d-6287762a2420 | |
| pub(crate) trait SysDbClone { | |
| fn clone_box(&self) -> Box<dyn SysDb>; | |
| } | |
| impl<T> SysDbClone for T | |
| where | |
| T: 'static + SysDb + Clone, | |
| { | |
| fn clone_box(&self) -> Box<dyn SysDb> { | |
| Box::new(self.clone()) | |
| } | |
| } | |
| impl Clone for Box<dyn SysDb> { | |
| fn clone(&self) -> Box<dyn SysDb> { | |
| self.clone_box() | |
| } | |
| } | |
| // Since this uses tonic transport channel, cloning is cheap. Each client only supports | |
| // one inflight request at a time, so we need to clone the client for each requester. | |
| pub(crate) struct GrpcSysDb { | |
| client: sys_db_client::SysDbClient<tonic::transport::Channel>, | |
| } | |
| pub(crate) enum GrpcSysDbError { | |
| FailedToConnect( tonic::transport::Error), | |
| } | |
| impl ChromaError for GrpcSysDbError { | |
| fn code(&self) -> ErrorCodes { | |
| match self { | |
| GrpcSysDbError::FailedToConnect(_) => ErrorCodes::Internal, | |
| } | |
| } | |
| } | |
| impl Configurable for GrpcSysDb { | |
| async fn try_from_config(worker_config: &WorkerConfig) -> Result<Self, Box<dyn ChromaError>> { | |
| match &worker_config.sysdb { | |
| SysDbConfig::Grpc(my_config) => { | |
| let host = &my_config.host; | |
| let port = &my_config.port; | |
| println!("Connecting to sysdb at {}:{}", host, port); | |
| let connection_string = format!("http://{}:{}", host, port); | |
| let client = sys_db_client::SysDbClient::connect(connection_string).await; | |
| match client { | |
| Ok(client) => { | |
| return Ok(GrpcSysDb { client: client }); | |
| } | |
| Err(e) => { | |
| return Err(Box::new(GrpcSysDbError::FailedToConnect(e))); | |
| } | |
| } | |
| } | |
| } | |
| } | |
| } | |
| impl SysDb for GrpcSysDb { | |
| async fn get_collections( | |
| &mut self, | |
| collection_id: Option<Uuid>, | |
| topic: Option<String>, | |
| name: Option<String>, | |
| tenant: Option<String>, | |
| database: Option<String>, | |
| ) -> Result<Vec<Collection>, GetCollectionsError> { | |
| // TODO: move off of status into our own error type | |
| let collection_id_str; | |
| match collection_id { | |
| Some(id) => { | |
| collection_id_str = Some(id.to_string()); | |
| } | |
| None => { | |
| collection_id_str = None; | |
| } | |
| } | |
| let res = self | |
| .client | |
| .get_collections(chroma_proto::GetCollectionsRequest { | |
| id: collection_id_str, | |
| topic: topic, | |
| name: name, | |
| tenant: if tenant.is_some() { | |
| tenant.unwrap() | |
| } else { | |
| DEFAULT_TENANT.to_string() | |
| }, | |
| database: if database.is_some() { | |
| database.unwrap() | |
| } else { | |
| DEFAULT_DATBASE.to_string() | |
| }, | |
| }) | |
| .await; | |
| match res { | |
| Ok(res) => { | |
| let collections = res.into_inner().collections; | |
| let collections = collections | |
| .into_iter() | |
| .map(|proto_collection| proto_collection.try_into()) | |
| .collect::<Result<Vec<Collection>, CollectionConversionError>>(); | |
| match collections { | |
| Ok(collections) => { | |
| return Ok(collections); | |
| } | |
| Err(e) => { | |
| return Err(GetCollectionsError::ConversionError(e)); | |
| } | |
| } | |
| } | |
| Err(e) => { | |
| return Err(GetCollectionsError::FailedToGetCollections(e)); | |
| } | |
| } | |
| } | |
| async fn get_segments( | |
| &mut self, | |
| id: Option<Uuid>, | |
| r#type: Option<String>, | |
| scope: Option<SegmentScope>, | |
| topic: Option<String>, | |
| collection: Option<Uuid>, | |
| ) -> Result<Vec<Segment>, GetSegmentsError> { | |
| let res = self | |
| .client | |
| .get_segments(chroma_proto::GetSegmentsRequest { | |
| // TODO: modularize | |
| id: if id.is_some() { | |
| Some(id.unwrap().to_string()) | |
| } else { | |
| None | |
| }, | |
| r#type: r#type, | |
| scope: if scope.is_some() { | |
| Some(scope.unwrap() as i32) | |
| } else { | |
| None | |
| }, | |
| topic: topic, | |
| collection: if collection.is_some() { | |
| Some(collection.unwrap().to_string()) | |
| } else { | |
| None | |
| }, | |
| }) | |
| .await; | |
| match res { | |
| Ok(res) => { | |
| let segments = res.into_inner().segments; | |
| let converted_segments = segments | |
| .into_iter() | |
| .map(|proto_segment| proto_segment.try_into()) | |
| .collect::<Result<Vec<Segment>, SegmentConversionError>>(); | |
| match converted_segments { | |
| Ok(segments) => { | |
| return Ok(segments); | |
| } | |
| Err(e) => { | |
| return Err(GetSegmentsError::ConversionError(e)); | |
| } | |
| } | |
| } | |
| Err(e) => { | |
| return Err(GetSegmentsError::FailedToGetSegments(e)); | |
| } | |
| } | |
| } | |
| } | |
| // TODO: This should use our sysdb errors from the proto definition | |
| // We will have to do an error uniformization pass at some point | |
| pub(crate) enum GetCollectionsError { | |
| FailedToGetCollections( tonic::Status), | |
| ConversionError( CollectionConversionError), | |
| } | |
| impl ChromaError for GetCollectionsError { | |
| fn code(&self) -> ErrorCodes { | |
| match self { | |
| GetCollectionsError::FailedToGetCollections(_) => ErrorCodes::Internal, | |
| GetCollectionsError::ConversionError(_) => ErrorCodes::Internal, | |
| } | |
| } | |
| } | |
| pub(crate) enum GetSegmentsError { | |
| FailedToGetSegments( tonic::Status), | |
| ConversionError( SegmentConversionError), | |
| } | |
| impl ChromaError for GetSegmentsError { | |
| fn code(&self) -> ErrorCodes { | |
| match self { | |
| GetSegmentsError::FailedToGetSegments(_) => ErrorCodes::Internal, | |
| GetSegmentsError::ConversionError(_) => ErrorCodes::Internal, | |
| } | |
| } | |
| } | |