Avoid accidentally taking the api_key when requesting an assist

This commit is contained in:
Antonio Scandurra 2023-06-05 11:25:21 +02:00
parent f00f16fe37
commit bef6932da7
2 changed files with 15 additions and 14 deletions

View file

@ -2,11 +2,9 @@ pub mod assistant;
mod assistant_settings; mod assistant_settings;
pub use assistant::AssistantPanel; pub use assistant::AssistantPanel;
use gpui::{actions, AppContext}; use gpui::AppContext;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
actions!(ai, [Assist]);
// Data types for chat completion requests // Data types for chat completion requests
#[derive(Serialize)] #[derive(Serialize)]
struct OpenAIRequest { struct OpenAIRequest {

View file

@ -16,7 +16,7 @@ use gpui::{
use isahc::{http::StatusCode, Request, RequestExt}; use isahc::{http::StatusCode, Request, RequestExt};
use language::{language_settings::SoftWrap, Buffer, LanguageRegistry}; use language::{language_settings::SoftWrap, Buffer, LanguageRegistry};
use settings::SettingsStore; use settings::SettingsStore;
use std::{cell::Cell, io, rc::Rc, sync::Arc, time::Duration}; use std::{cell::RefCell, io, rc::Rc, sync::Arc, time::Duration};
use tiktoken_rs::model::get_context_size; use tiktoken_rs::model::get_context_size;
use util::{post_inc, ResultExt, TryFutureExt}; use util::{post_inc, ResultExt, TryFutureExt};
use workspace::{ use workspace::{
@ -62,7 +62,7 @@ pub struct AssistantPanel {
width: Option<f32>, width: Option<f32>,
height: Option<f32>, height: Option<f32>,
pane: ViewHandle<Pane>, pane: ViewHandle<Pane>,
api_key: Rc<Cell<Option<String>>>, api_key: Rc<RefCell<Option<String>>>,
api_key_editor: Option<ViewHandle<Editor>>, api_key_editor: Option<ViewHandle<Editor>>,
has_read_credentials: bool, has_read_credentials: bool,
languages: Arc<LanguageRegistry>, languages: Arc<LanguageRegistry>,
@ -136,7 +136,7 @@ impl AssistantPanel {
let mut this = Self { let mut this = Self {
pane, pane,
api_key: Rc::new(Cell::new(None)), api_key: Rc::new(RefCell::new(None)),
api_key_editor: None, api_key_editor: None,
has_read_credentials: false, has_read_credentials: false,
languages: workspace.app_state().languages.clone(), languages: workspace.app_state().languages.clone(),
@ -199,7 +199,7 @@ impl AssistantPanel {
cx.platform() cx.platform()
.write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes()) .write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes())
.log_err(); .log_err();
self.api_key.set(Some(api_key)); *self.api_key.borrow_mut() = Some(api_key);
self.api_key_editor.take(); self.api_key_editor.take();
cx.focus_self(); cx.focus_self();
cx.notify(); cx.notify();
@ -333,7 +333,7 @@ impl Panel for AssistantPanel {
fn set_active(&mut self, active: bool, cx: &mut ViewContext<Self>) { fn set_active(&mut self, active: bool, cx: &mut ViewContext<Self>) {
if active { if active {
if self.api_key.clone().take().is_none() && !self.has_read_credentials { if self.api_key.borrow().is_none() && !self.has_read_credentials {
self.has_read_credentials = true; self.has_read_credentials = true;
let api_key = if let Some((_, api_key)) = cx let api_key = if let Some((_, api_key)) = cx
.platform() .platform()
@ -346,7 +346,7 @@ impl Panel for AssistantPanel {
None None
}; };
if let Some(api_key) = api_key { if let Some(api_key) = api_key {
self.api_key.set(Some(api_key)); *self.api_key.borrow_mut() = Some(api_key);
} else if self.api_key_editor.is_none() { } else if self.api_key_editor.is_none() {
self.api_key_editor = Some(build_api_key_editor(cx)); self.api_key_editor = Some(build_api_key_editor(cx));
cx.notify(); cx.notify();
@ -403,7 +403,7 @@ struct Assistant {
token_count: Option<usize>, token_count: Option<usize>,
max_token_count: usize, max_token_count: usize,
pending_token_count: Task<Option<()>>, pending_token_count: Task<Option<()>>,
api_key: Rc<Cell<Option<String>>>, api_key: Rc<RefCell<Option<String>>>,
_subscriptions: Vec<Subscription>, _subscriptions: Vec<Subscription>,
} }
@ -413,7 +413,7 @@ impl Entity for Assistant {
impl Assistant { impl Assistant {
fn new( fn new(
api_key: Rc<Cell<Option<String>>>, api_key: Rc<RefCell<Option<String>>>,
language_registry: Arc<LanguageRegistry>, language_registry: Arc<LanguageRegistry>,
cx: &mut ModelContext<Self>, cx: &mut ModelContext<Self>,
) -> Self { ) -> Self {
@ -504,7 +504,8 @@ impl Assistant {
stream: true, stream: true,
}; };
if let Some(api_key) = self.api_key.clone().take() { let api_key = self.api_key.borrow().clone();
if let Some(api_key) = api_key {
let stream = stream_completion(api_key, cx.background().clone(), request); let stream = stream_completion(api_key, cx.background().clone(), request);
let response = self.push_message(Role::Assistant, cx); let response = self.push_message(Role::Assistant, cx);
self.push_message(Role::User, cx); self.push_message(Role::User, cx);
@ -600,7 +601,7 @@ struct AssistantEditor {
impl AssistantEditor { impl AssistantEditor {
fn new( fn new(
api_key: Rc<Cell<Option<String>>>, api_key: Rc<RefCell<Option<String>>>,
language_registry: Arc<LanguageRegistry>, language_registry: Arc<LanguageRegistry>,
cx: &mut ViewContext<Self>, cx: &mut ViewContext<Self>,
) -> Self { ) -> Self {
@ -846,7 +847,9 @@ async fn stream_completion(
while let Some(line) = lines.next().await { while let Some(line) = lines.next().await {
if let Some(event) = parse_line(line).transpose() { if let Some(event) = parse_line(line).transpose() {
tx.unbounded_send(event).log_err(); if tx.unbounded_send(event).is_err() {
break;
}
} }
} }