From bef6932da794877ab1011d1b3cacf28a9e4d17d0 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Mon, 5 Jun 2023 11:25:21 +0200 Subject: [PATCH] Avoid accidentally taking the `api_key` when requesting an assist --- crates/ai/src/ai.rs | 4 +--- crates/ai/src/assistant.rs | 25 ++++++++++++++----------- 2 files changed, 15 insertions(+), 14 deletions(-) diff --git a/crates/ai/src/ai.rs b/crates/ai/src/ai.rs index a5d5666a72..11704de03e 100644 --- a/crates/ai/src/ai.rs +++ b/crates/ai/src/ai.rs @@ -2,11 +2,9 @@ pub mod assistant; mod assistant_settings; pub use assistant::AssistantPanel; -use gpui::{actions, AppContext}; +use gpui::AppContext; use serde::{Deserialize, Serialize}; -actions!(ai, [Assist]); - // Data types for chat completion requests #[derive(Serialize)] struct OpenAIRequest { diff --git a/crates/ai/src/assistant.rs b/crates/ai/src/assistant.rs index 68f722f1ee..7020f9f38c 100644 --- a/crates/ai/src/assistant.rs +++ b/crates/ai/src/assistant.rs @@ -16,7 +16,7 @@ use gpui::{ use isahc::{http::StatusCode, Request, RequestExt}; use language::{language_settings::SoftWrap, Buffer, LanguageRegistry}; use settings::SettingsStore; -use std::{cell::Cell, io, rc::Rc, sync::Arc, time::Duration}; +use std::{cell::RefCell, io, rc::Rc, sync::Arc, time::Duration}; use tiktoken_rs::model::get_context_size; use util::{post_inc, ResultExt, TryFutureExt}; use workspace::{ @@ -62,7 +62,7 @@ pub struct AssistantPanel { width: Option, height: Option, pane: ViewHandle, - api_key: Rc>>, + api_key: Rc>>, api_key_editor: Option>, has_read_credentials: bool, languages: Arc, @@ -136,7 +136,7 @@ impl AssistantPanel { let mut this = Self { pane, - api_key: Rc::new(Cell::new(None)), + api_key: Rc::new(RefCell::new(None)), api_key_editor: None, has_read_credentials: false, languages: workspace.app_state().languages.clone(), @@ -199,7 +199,7 @@ impl AssistantPanel { cx.platform() .write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes()) .log_err(); - self.api_key.set(Some(api_key)); + *self.api_key.borrow_mut() = Some(api_key); self.api_key_editor.take(); cx.focus_self(); cx.notify(); @@ -333,7 +333,7 @@ impl Panel for AssistantPanel { fn set_active(&mut self, active: bool, cx: &mut ViewContext) { if active { - if self.api_key.clone().take().is_none() && !self.has_read_credentials { + if self.api_key.borrow().is_none() && !self.has_read_credentials { self.has_read_credentials = true; let api_key = if let Some((_, api_key)) = cx .platform() @@ -346,7 +346,7 @@ impl Panel for AssistantPanel { None }; if let Some(api_key) = api_key { - self.api_key.set(Some(api_key)); + *self.api_key.borrow_mut() = Some(api_key); } else if self.api_key_editor.is_none() { self.api_key_editor = Some(build_api_key_editor(cx)); cx.notify(); @@ -403,7 +403,7 @@ struct Assistant { token_count: Option, max_token_count: usize, pending_token_count: Task>, - api_key: Rc>>, + api_key: Rc>>, _subscriptions: Vec, } @@ -413,7 +413,7 @@ impl Entity for Assistant { impl Assistant { fn new( - api_key: Rc>>, + api_key: Rc>>, language_registry: Arc, cx: &mut ModelContext, ) -> Self { @@ -504,7 +504,8 @@ impl Assistant { stream: true, }; - if let Some(api_key) = self.api_key.clone().take() { + let api_key = self.api_key.borrow().clone(); + if let Some(api_key) = api_key { let stream = stream_completion(api_key, cx.background().clone(), request); let response = self.push_message(Role::Assistant, cx); self.push_message(Role::User, cx); @@ -600,7 +601,7 @@ struct AssistantEditor { impl AssistantEditor { fn new( - api_key: Rc>>, + api_key: Rc>>, language_registry: Arc, cx: &mut ViewContext, ) -> Self { @@ -846,7 +847,9 @@ async fn stream_completion( while let Some(line) = lines.next().await { if let Some(event) = parse_line(line).transpose() { - tx.unbounded_send(event).log_err(); + if tx.unbounded_send(event).is_err() { + break; + } } }