From 50482a6bc2c5274634f9f1f2446b05e425a142ad Mon Sep 17 00:00:00 2001 From: Marshall Bowers Date: Thu, 7 Aug 2025 19:00:45 -0400 Subject: [PATCH] language_model: Refresh the LLM token upon receiving a `UserUpdated` message from Cloud (#35839) This PR makes it so we refresh the LLM token upon receiving a `UserUpdated` message from Cloud over the WebSocket connection. Release Notes: - N/A --- Cargo.lock | 1 + crates/client/src/client.rs | 4 +-- crates/language_model/Cargo.toml | 1 + .../language_model/src/model/cloud_model.rs | 34 +++++++++---------- 4 files changed, 21 insertions(+), 19 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 8c1f1d00ba..39ee75f6dd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -9127,6 +9127,7 @@ dependencies = [ "anyhow", "base64 0.22.1", "client", + "cloud_api_types", "cloud_llm_client", "collections", "futures 0.3.31", diff --git a/crates/client/src/client.rs b/crates/client/src/client.rs index 12ea4bcd3e..f09c012a85 100644 --- a/crates/client/src/client.rs +++ b/crates/client/src/client.rs @@ -193,7 +193,7 @@ pub fn init(client: &Arc, cx: &mut App) { }); } -pub type MessageToClientHandler = Box; +pub type MessageToClientHandler = Box; struct GlobalClient(Arc); @@ -1684,7 +1684,7 @@ impl Client { pub fn add_message_to_client_handler( self: &Arc, - handler: impl Fn(&MessageToClient, &App) + Send + Sync + 'static, + handler: impl Fn(&MessageToClient, &mut App) + Send + Sync + 'static, ) { self.message_to_client_handlers .lock() diff --git a/crates/language_model/Cargo.toml b/crates/language_model/Cargo.toml index 841be60b0e..f9920623b5 100644 --- a/crates/language_model/Cargo.toml +++ b/crates/language_model/Cargo.toml @@ -20,6 +20,7 @@ anthropic = { workspace = true, features = ["schemars"] } anyhow.workspace = true base64.workspace = true client.workspace = true +cloud_api_types.workspace = true cloud_llm_client.workspace = true collections.workspace = true futures.workspace = true diff --git a/crates/language_model/src/model/cloud_model.rs b/crates/language_model/src/model/cloud_model.rs index 8ae5893410..3b4c1fa269 100644 --- a/crates/language_model/src/model/cloud_model.rs +++ b/crates/language_model/src/model/cloud_model.rs @@ -3,11 +3,9 @@ use std::sync::Arc; use anyhow::Result; use client::Client; +use cloud_api_types::websocket_protocol::MessageToClient; use cloud_llm_client::Plan; -use gpui::{ - App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Global, ReadGlobal as _, -}; -use proto::TypedEnvelope; +use gpui::{App, AppContext as _, Context, Entity, EventEmitter, Global, ReadGlobal as _}; use smol::lock::{RwLock, RwLockUpgradableReadGuard, RwLockWriteGuard}; use thiserror::Error; @@ -82,9 +80,7 @@ impl Global for GlobalRefreshLlmTokenListener {} pub struct RefreshLlmTokenEvent; -pub struct RefreshLlmTokenListener { - _llm_token_subscription: client::Subscription, -} +pub struct RefreshLlmTokenListener; impl EventEmitter for RefreshLlmTokenListener {} @@ -99,17 +95,21 @@ impl RefreshLlmTokenListener { } fn new(client: Arc, cx: &mut Context) -> Self { - Self { - _llm_token_subscription: client - .add_message_handler(cx.weak_entity(), Self::handle_refresh_llm_token), - } + client.add_message_to_client_handler({ + let this = cx.entity(); + move |message, cx| { + Self::handle_refresh_llm_token(this.clone(), message, cx); + } + }); + + Self } - async fn handle_refresh_llm_token( - this: Entity, - _: TypedEnvelope, - mut cx: AsyncApp, - ) -> Result<()> { - this.update(&mut cx, |_this, cx| cx.emit(RefreshLlmTokenEvent)) + fn handle_refresh_llm_token(this: Entity, message: &MessageToClient, cx: &mut App) { + match message { + MessageToClient::UserUpdated => { + this.update(cx, |_this, cx| cx.emit(RefreshLlmTokenEvent)); + } + } } }