From acba38dabd77785519cb733200a08562ca23d33c Mon Sep 17 00:00:00 2001
From: "gcp-cherry-pick-bot[bot]" <98988430+gcp-cherry-pick-bot[bot]@users.noreply.github.com>
Date: Fri, 11 Jul 2025 11:48:51 -0400
Subject: [PATCH] language_models: Refresh the list of models when the LLM
 token is refreshed (cherry-pick #34222) (#34294)

Cherry-picked language_models: Refresh the list of models when the LLM token
is refreshed (#34222)

This PR makes it so we refresh the list of models whenever the LLM token is
refreshed. This allows us to add or remove models based on the plan in the
new token.

Release Notes:

- Fixed model list not refreshing when subscribing to Zed Pro.

---------

Co-authored-by: Bennet Bo Fenner
Co-authored-by: Marshall Bowers
Co-authored-by: Bennet Bo Fenner
---
 crates/language_models/src/provider/cloud.rs | 87 ++++++++++----------
 1 file changed, 44 insertions(+), 43 deletions(-)

diff --git a/crates/language_models/src/provider/cloud.rs b/crates/language_models/src/provider/cloud.rs
index 9b7fee228a..518f386ebe 100644
--- a/crates/language_models/src/provider/cloud.rs
+++ b/crates/language_models/src/provider/cloud.rs
@@ -166,46 +166,9 @@ impl State {
                         }
 
                         let response = Self::fetch_models(client, llm_api_token, use_cloud).await?;
-                        cx.update(|cx| {
-                            this.update(cx, |this, cx| {
-                                let mut models = Vec::new();
-
-                                for model in response.models {
-                                    models.push(Arc::new(model.clone()));
-
-                                    // Right now we represent thinking variants of models as separate models on the client,
-                                    // so we need to insert variants for any model that supports thinking.
-                                    if model.supports_thinking {
-                                        models.push(Arc::new(zed_llm_client::LanguageModel {
-                                            id: zed_llm_client::LanguageModelId(
-                                                format!("{}-thinking", model.id).into(),
-                                            ),
-                                            display_name: format!("{} Thinking", model.display_name),
-                                            ..model
-                                        }));
-                                    }
-                                }
-
-                                this.default_model = models
-                                    .iter()
-                                    .find(|model| model.id == response.default_model)
-                                    .cloned();
-                                this.default_fast_model = models
-                                    .iter()
-                                    .find(|model| model.id == response.default_fast_model)
-                                    .cloned();
-                                this.recommended_models = response
-                                    .recommended_models
-                                    .iter()
-                                    .filter_map(|id| models.iter().find(|model| &model.id == id))
-                                    .cloned()
-                                    .collect();
-                                this.models = models;
-                                cx.notify();
-                            })
-                        })??;
-
-                        anyhow::Ok(())
+                        this.update(cx, |this, cx| {
+                            this.update_models(response, cx);
+                        })
                     })
                     .await
                     .context("failed to fetch Zed models")
@@ -216,12 +179,15 @@ impl State {
             }),
             _llm_token_subscription: cx.subscribe(
                 &refresh_llm_token_listener,
-                |this, _listener, _event, cx| {
+                move |this, _listener, _event, cx| {
                     let client = this.client.clone();
                     let llm_api_token = this.llm_api_token.clone();
-                    cx.spawn(async move |_this, _cx| {
+                    cx.spawn(async move |this, cx| {
                         llm_api_token.refresh(&client).await?;
-                        anyhow::Ok(())
+                        let response = Self::fetch_models(client, llm_api_token, use_cloud).await?;
+                        this.update(cx, |this, cx| {
+                            this.update_models(response, cx);
+                        })
                     })
                     .detach_and_log_err(cx);
                 },
@@ -264,6 +230,41 @@ impl State {
         }));
     }
 
+    fn update_models(&mut self, response: ListModelsResponse, cx: &mut Context<Self>) {
+        let mut models = Vec::new();
+
+        for model in response.models {
+            models.push(Arc::new(model.clone()));
+
+            // Right now we represent thinking variants of models as separate models on the client,
+            // so we need to insert variants for any model that supports thinking.
+            if model.supports_thinking {
+                models.push(Arc::new(zed_llm_client::LanguageModel {
+                    id: zed_llm_client::LanguageModelId(format!("{}-thinking", model.id).into()),
+                    display_name: format!("{} Thinking", model.display_name),
+                    ..model
+                }));
+            }
+        }
+
+        self.default_model = models
+            .iter()
+            .find(|model| model.id == response.default_model)
+            .cloned();
+        self.default_fast_model = models
+            .iter()
+            .find(|model| model.id == response.default_fast_model)
+            .cloned();
+        self.recommended_models = response
+            .recommended_models
+            .iter()
+            .filter_map(|id| models.iter().find(|model| &model.id == id))
+            .cloned()
+            .collect();
+        self.models = models;
+        cx.notify();
+    }
+
     async fn fetch_models(
         client: Arc<Client>,
         llm_api_token: LlmApiToken,