language_models: Refresh the list of models when the LLM token is refreshed (cherry-pick #34222) (#34294)

Cherry-picked: language_models: Refresh the list of models when the LLM token is refreshed (#34222)

This PR refreshes the list of models whenever the LLM token is refreshed. This lets us add or remove models based on the plan encoded in the new token.

Release Notes:

- Fixed the model list not refreshing when subscribing to Zed Pro.

Co-authored-by: Bennet Bo Fenner <bennetbo@gmx.de>
Co-authored-by: Marshall Bowers <git@maxdeviant.com>
This commit is contained in:
parent c1b3111c15
commit acba38dabd

1 changed file with 44 additions and 43 deletions
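For orientation before the diff: the pattern this commit introduces is "whenever the auth token is refreshed, refetch the model list, so that plan-gated models can appear or disappear." Below is a minimal, self-contained sketch of that idea; every name in it (`Token`, `Registry`, `fetch_models`, the "free"/"pro" plans) is an illustrative stand-in, not Zed's actual API.

```rust
use std::sync::mpsc;

#[derive(Debug)]
struct Token(String);

#[derive(Debug, Default)]
struct Registry {
    models: Vec<String>,
}

// Stand-in for State::fetch_models; in Zed this calls the LLM service.
fn fetch_models(token: &Token) -> Vec<String> {
    // Pretend the server returns more models for a "pro" token.
    if token.0 == "pro" {
        vec!["claude-sonnet-4".into(), "claude-sonnet-4-thinking".into()]
    } else {
        vec!["claude-sonnet-4".into()]
    }
}

fn main() {
    let (tx, rx) = mpsc::channel::<Token>();
    let mut registry = Registry::default();

    // Simulate two token refreshes: free plan, then an upgrade to Pro.
    tx.send(Token("free".into())).unwrap();
    tx.send(Token("pro".into())).unwrap();
    drop(tx);

    // The subscription: every refreshed token triggers a model refetch.
    for token in rx {
        registry.models = fetch_models(&token);
        println!("{:?} -> {:?}", token, registry.models);
    }
}
```

The actual change wires this up through GPUI's event subscription, as the diff shows.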
@@ -166,46 +166,9 @@ impl State {
             }

             let response = Self::fetch_models(client, llm_api_token, use_cloud).await?;
-            cx.update(|cx| {
-                this.update(cx, |this, cx| {
-                    let mut models = Vec::new();
-
-                    for model in response.models {
-                        models.push(Arc::new(model.clone()));
-
-                        // Right now we represent thinking variants of models as separate models on the client,
-                        // so we need to insert variants for any model that supports thinking.
-                        if model.supports_thinking {
-                            models.push(Arc::new(zed_llm_client::LanguageModel {
-                                id: zed_llm_client::LanguageModelId(
-                                    format!("{}-thinking", model.id).into(),
-                                ),
-                                display_name: format!("{} Thinking", model.display_name),
-                                ..model
-                            }));
-                        }
-                    }
-
-                    this.default_model = models
-                        .iter()
-                        .find(|model| model.id == response.default_model)
-                        .cloned();
-                    this.default_fast_model = models
-                        .iter()
-                        .find(|model| model.id == response.default_fast_model)
-                        .cloned();
-                    this.recommended_models = response
-                        .recommended_models
-                        .iter()
-                        .filter_map(|id| models.iter().find(|model| &model.id == id))
-                        .cloned()
-                        .collect();
-                    this.models = models;
-                    cx.notify();
-                })
-            })??;
-
-            anyhow::Ok(())
+            this.update(cx, |this, cx| {
+                this.update_models(response, cx);
+            })
         })
         .await
         .context("failed to fetch Zed models")
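One detail worth noting in the removed code is the trailing `})??;`. In GPUI, updating state from an async task is fallible at two levels (the app may have shut down, and the weak entity handle may have been dropped), so nesting `this.update` inside `cx.update` yields a `Result` inside a `Result` that needs two `?`s. After the refactor, the closure returns the single `this.update(...)` result directly. A dependency-free sketch of that flattening, with made-up function and error types:

```rust
// Illustrative only: nested fallible "update" calls yield a nested Result,
// which is why the removed code needed `})??;`. The error type is a
// stand-in; in GPUI the failures are a dropped app or a dropped entity.
fn app_update<T>(f: impl FnOnce() -> T) -> Result<T, String> {
    Ok(f()) // would fail if the app had shut down
}

fn entity_update<T>(f: impl FnOnce() -> T) -> Result<T, String> {
    Ok(f()) // would fail if the entity had been dropped
}

fn old_shape() -> Result<(), String> {
    // Result<Result<(), String>, String>, flattened with two `?`s:
    app_update(|| entity_update(|| { /* mutate state */ }))??;
    Ok(())
}

fn new_shape() -> Result<(), String> {
    // The refactored code returns the single inner Result directly:
    entity_update(|| { /* mutate state */ })
}

fn main() {
    old_shape().unwrap();
    new_shape().unwrap();
}
```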
@@ -216,12 +179,15 @@ impl State {
             }),
             _llm_token_subscription: cx.subscribe(
                 &refresh_llm_token_listener,
-                |this, _listener, _event, cx| {
+                move |this, _listener, _event, cx| {
                     let client = this.client.clone();
                     let llm_api_token = this.llm_api_token.clone();
-                    cx.spawn(async move |_this, _cx| {
+                    cx.spawn(async move |this, cx| {
                         llm_api_token.refresh(&client).await?;
-                        anyhow::Ok(())
+                        let response = Self::fetch_models(client, llm_api_token, use_cloud).await?;
+                        this.update(cx, |this, cx| {
+                            this.update_models(response, cx);
+                        })
                     })
                     .detach_and_log_err(cx);
                 },
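This subscription change is the heart of the fix: the token-refresh handler previously only refreshed the token, whereas it now also refetches the models with the new token and applies them via `update_models`. The closure becomes `move` (apparently so it can capture `use_cloud` from the enclosing scope), and the spawned task keeps the non-underscore `this`/`cx` bindings because it now actually uses them. Below is a simplified, synchronous analogue; every name is a stand-in, and the `if let Err` approximates GPUI's `detach_and_log_err`, which logs a failure instead of propagating it.

```rust
// Stand-in for llm_api_token.refresh(&client).
fn refresh_token() -> Result<String, String> {
    Ok("refreshed-token".to_string())
}

// Stand-in for Self::fetch_models(client, llm_api_token, use_cloud).
fn fetch_models(token: &str) -> Result<Vec<String>, String> {
    Ok(vec![format!("model-available-to-{token}")])
}

fn on_token_refreshed() {
    let task = || -> Result<(), String> {
        let token = refresh_token()?;           // refresh first...
        let models = fetch_models(&token)?;     // ...then refetch with the new token
        println!("updated models: {models:?}"); // stands in for update_models(response, cx)
        Ok(())
    };
    // Like detach_and_log_err: failures are logged, never bubbled up.
    if let Err(err) = task() {
        eprintln!("failed to refresh models: {err}");
    }
}

fn main() {
    on_token_refreshed();
}
```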
@@ -264,6 +230,41 @@ impl State {
         }));
     }

+    fn update_models(&mut self, response: ListModelsResponse, cx: &mut Context<Self>) {
+        let mut models = Vec::new();
+
+        for model in response.models {
+            models.push(Arc::new(model.clone()));
+
+            // Right now we represent thinking variants of models as separate models on the client,
+            // so we need to insert variants for any model that supports thinking.
+            if model.supports_thinking {
+                models.push(Arc::new(zed_llm_client::LanguageModel {
+                    id: zed_llm_client::LanguageModelId(format!("{}-thinking", model.id).into()),
+                    display_name: format!("{} Thinking", model.display_name),
+                    ..model
+                }));
+            }
+        }
+
+        self.default_model = models
+            .iter()
+            .find(|model| model.id == response.default_model)
+            .cloned();
+        self.default_fast_model = models
+            .iter()
+            .find(|model| model.id == response.default_fast_model)
+            .cloned();
+        self.recommended_models = response
+            .recommended_models
+            .iter()
+            .filter_map(|id| models.iter().find(|model| &model.id == id))
+            .cloned()
+            .collect();
+        self.models = models;
+        cx.notify();
+    }
+
     async fn fetch_models(
         client: Arc<Client>,
         llm_api_token: LlmApiToken,
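The new `update_models` method centralizes what the fetch path and the token-refresh path would otherwise duplicate. Its one non-obvious step is the thinking-variant expansion: each thinking-capable model is duplicated into a `-thinking` variant by overriding only the id and display name via struct-update syntax (`..model`). A pared-down, runnable sketch, where `LanguageModel` is a simplified stand-in for `zed_llm_client::LanguageModel`:

```rust
use std::sync::Arc;

#[derive(Clone, Debug)]
struct LanguageModel {
    id: String,
    display_name: String,
    supports_thinking: bool,
}

fn expand(models: Vec<LanguageModel>) -> Vec<Arc<LanguageModel>> {
    let mut out = Vec::new();
    for model in models {
        out.push(Arc::new(model.clone()));
        if model.supports_thinking {
            // Clone the base model, overriding only the id and display name.
            out.push(Arc::new(LanguageModel {
                id: format!("{}-thinking", model.id),
                display_name: format!("{} Thinking", model.display_name),
                ..model
            }));
        }
    }
    out
}

fn main() {
    let models = vec![LanguageModel {
        id: "claude-sonnet-4".into(),
        display_name: "Claude Sonnet 4".into(),
        supports_thinking: true,
    }];
    for m in expand(models) {
        println!("{} ({})", m.id, m.display_name);
    }
}
```

Note also that `recommended_models` is rebuilt with `filter_map` over the server's id list rather than by filtering the local list, which preserves the server's recommended ordering.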