diff --git a/crates/collab/src/llm.rs b/crates/collab/src/llm.rs index 2072861363..3d13b6f812 100644 --- a/crates/collab/src/llm.rs +++ b/crates/collab/src/llm.rs @@ -27,7 +27,7 @@ pub use token::*; pub struct LlmState { pub config: Config, pub executor: Executor, - pub db: Option>, + pub db: Arc, pub http_client: IsahcHttpClient, active_user_count: RwLock, ActiveUserCount)>>, } @@ -36,25 +36,20 @@ const ACTIVE_USER_COUNT_CACHE_DURATION: Duration = Duration::seconds(30); impl LlmState { pub async fn new(config: Config, executor: Executor) -> Result> { - // TODO: This is temporary until we have the LLM database stood up. - let db = if config.is_development() { - let database_url = config - .llm_database_url - .as_ref() - .ok_or_else(|| anyhow!("missing LLM_DATABASE_URL"))?; - let max_connections = config - .llm_database_max_connections - .ok_or_else(|| anyhow!("missing LLM_DATABASE_MAX_CONNECTIONS"))?; + let database_url = config + .llm_database_url + .as_ref() + .ok_or_else(|| anyhow!("missing LLM_DATABASE_URL"))?; + let max_connections = config + .llm_database_max_connections + .ok_or_else(|| anyhow!("missing LLM_DATABASE_MAX_CONNECTIONS"))?; - let mut db_options = db::ConnectOptions::new(database_url); - db_options.max_connections(max_connections); - let mut db = LlmDatabase::new(db_options, executor.clone()).await?; - db.initialize().await?; + let mut db_options = db::ConnectOptions::new(database_url); + db_options.max_connections(max_connections); + let mut db = LlmDatabase::new(db_options, executor.clone()).await?; + db.initialize().await?; - Some(Arc::new(db)) - } else { - None - }; + let db = Arc::new(db); let user_agent = format!("Zed Server/{}", env!("CARGO_PKG_VERSION")); let http_client = IsahcHttpClient::builder() @@ -62,11 +57,8 @@ impl LlmState { .build() .context("failed to construct http client")?; - let initial_active_user_count = if let Some(db) = &db { - Some((Utc::now(), db.get_active_user_count(Utc::now()).await?)) - } else { - None - }; + let initial_active_user_count = + Some((Utc::now(), db.get_active_user_count(Utc::now()).await?)); let this = Self { config, @@ -88,14 +80,10 @@ impl LlmState { } } - if let Some(db) = &self.db { - let mut cache = self.active_user_count.write().await; - let new_count = db.get_active_user_count(now).await?; - *cache = Some((now, new_count)); - Ok(new_count) - } else { - Ok(ActiveUserCount::default()) - } + let mut cache = self.active_user_count.write().await; + let new_count = self.db.get_active_user_count(now).await?; + *cache = Some((now, new_count)); + Ok(new_count) } } @@ -165,9 +153,7 @@ async fn perform_completion( let user_id = claims.user_id as i32; - if state.db.is_some() { - check_usage_limit(&state, params.provider, &model, &claims).await?; - } + check_usage_limit(&state, params.provider, &model, &claims).await?; match params.provider { LanguageModelProvider::Anthropic => { @@ -199,14 +185,14 @@ async fn perform_completion( ) .await?; - let mut recorder = state.db.clone().map(|db| UsageRecorder { - db, + let mut recorder = UsageRecorder { + db: state.db.clone(), executor: state.executor.clone(), user_id, provider: params.provider, model, token_count: 0, - }); + }; let stream = chunks.map(move |event| { let mut buffer = Vec::new(); @@ -216,10 +202,8 @@ async fn perform_completion( message: anthropic::Response { usage, .. }, } | anthropic::Event::MessageDelta { usage, .. } => { - if let Some(recorder) = &mut recorder { - recorder.token_count += usage.input_tokens.unwrap_or(0) as usize; - recorder.token_count += usage.output_tokens.unwrap_or(0) as usize; - } + recorder.token_count += usage.input_tokens.unwrap_or(0) as usize; + recorder.token_count += usage.output_tokens.unwrap_or(0) as usize; } _ => {} } @@ -349,12 +333,9 @@ async fn check_usage_limit( model_name: &str, claims: &LlmTokenClaims, ) -> Result<()> { - let db = state + let model = state.db.model(provider, model_name)?; + let usage = state .db - .as_ref() - .ok_or_else(|| anyhow!("LLM database not configured"))?; - let model = db.model(provider, model_name)?; - let usage = db .get_usage(claims.user_id as i32, provider, model_name, Utc::now()) .await?; diff --git a/crates/collab/src/main.rs b/crates/collab/src/main.rs index 18515192d5..e0f3c7e573 100644 --- a/crates/collab/src/main.rs +++ b/crates/collab/src/main.rs @@ -248,11 +248,6 @@ async fn setup_app_database(config: &Config) -> Result<()> { } async fn setup_llm_database(config: &Config) -> Result<()> { - // TODO: This is temporary until we have the LLM database stood up. - if !config.is_development() { - return Ok(()); - } - let database_url = config .llm_database_url .as_ref() @@ -298,7 +293,12 @@ async fn handle_liveness_probe( state.db.get_all_users(0, 1).await?; } - if let Some(_llm_state) = llm_state {} + if let Some(llm_state) = llm_state { + llm_state + .db + .get_active_user_count(chrono::Utc::now()) + .await?; + } Ok("ok".to_string()) }