diff --git a/crates/ai/src/auth.rs b/crates/ai/src/auth.rs index c188c30797..cb3f2beabb 100644 --- a/crates/ai/src/auth.rs +++ b/crates/ai/src/auth.rs @@ -9,6 +9,8 @@ pub enum ProviderCredential { pub trait CredentialProvider: Send + Sync { fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential; + fn save_credentials(&self, cx: &AppContext, credential: ProviderCredential); + fn delete_credentials(&self, cx: &AppContext); } #[derive(Clone)] @@ -17,4 +19,6 @@ impl CredentialProvider for NullCredentialProvider { fn retrieve_credentials(&self, _cx: &AppContext) -> ProviderCredential { ProviderCredential::NotNeeded } + fn save_credentials(&self, cx: &AppContext, credential: ProviderCredential) {} + fn delete_credentials(&self, cx: &AppContext) {} } diff --git a/crates/ai/src/completion.rs b/crates/ai/src/completion.rs index 5b9bad4870..6a2806a5cb 100644 --- a/crates/ai/src/completion.rs +++ b/crates/ai/src/completion.rs @@ -17,6 +17,12 @@ pub trait CompletionProvider { fn retrieve_credentials(&self, cx: &AppContext) -> ProviderCredential { self.credential_provider().retrieve_credentials(cx) } + fn save_credentials(&self, cx: &AppContext, credential: ProviderCredential) { + self.credential_provider().save_credentials(cx, credential); + } + fn delete_credentials(&self, cx: &AppContext) { + self.credential_provider().delete_credentials(cx); + } fn complete( &self, prompt: Box, diff --git a/crates/ai/src/providers/open_ai/auth.rs b/crates/ai/src/providers/open_ai/auth.rs index c817ffea00..7cb51ab449 100644 --- a/crates/ai/src/providers/open_ai/auth.rs +++ b/crates/ai/src/providers/open_ai/auth.rs @@ -30,4 +30,17 @@ impl CredentialProvider for OpenAICredentialProvider { ProviderCredential::NoCredentials } } + fn save_credentials(&self, cx: &AppContext, credential: ProviderCredential) { + match credential { + ProviderCredential::Credentials { api_key } => { + cx.platform() + .write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes()) + .log_err(); + } + _ => {} + } + } + fn delete_credentials(&self, cx: &AppContext) { + cx.platform().delete_credentials(OPENAI_API_URL).log_err(); + } } diff --git a/crates/ai/src/providers/open_ai/completion.rs b/crates/ai/src/providers/open_ai/completion.rs index 9c9d205ff7..febe491123 100644 --- a/crates/ai/src/providers/open_ai/completion.rs +++ b/crates/ai/src/providers/open_ai/completion.rs @@ -13,7 +13,7 @@ use std::{ }; use crate::{ - auth::CredentialProvider, + auth::{CredentialProvider, ProviderCredential}, completion::{CompletionProvider, CompletionRequest}, models::LanguageModel, }; @@ -102,10 +102,17 @@ pub struct OpenAIResponseStreamEvent { } pub async fn stream_completion( - api_key: String, + credential: ProviderCredential, executor: Arc, request: Box, ) -> Result>> { + let api_key = match credential { + ProviderCredential::Credentials { api_key } => api_key, + _ => { + return Err(anyhow!("no credentials provider for completion")); + } + }; + let (tx, rx) = futures::channel::mpsc::unbounded::>(); let json_data = request.data()?; @@ -188,18 +195,22 @@ pub async fn stream_completion( pub struct OpenAICompletionProvider { model: OpenAILanguageModel, credential_provider: OpenAICredentialProvider, - api_key: String, + credential: ProviderCredential, executor: Arc, } impl OpenAICompletionProvider { - pub fn new(model_name: &str, api_key: String, executor: Arc) -> Self { + pub fn new( + model_name: &str, + credential: ProviderCredential, + executor: Arc, + ) -> Self { let model = OpenAILanguageModel::load(model_name); let credential_provider = OpenAICredentialProvider {}; Self { model, credential_provider, - api_key, + credential, executor, } } @@ -218,7 +229,8 @@ impl CompletionProvider for OpenAICompletionProvider { &self, prompt: Box, ) -> BoxFuture<'static, Result>>> { - let request = stream_completion(self.api_key.clone(), self.executor.clone(), prompt); + let credential = self.credential.clone(); + let request = stream_completion(credential, self.executor.clone(), prompt); async move { let response = request.await?; let stream = response diff --git a/crates/assistant/src/assistant_panel.rs b/crates/assistant/src/assistant_panel.rs index c899465ed2..f9187b8785 100644 --- a/crates/assistant/src/assistant_panel.rs +++ b/crates/assistant/src/assistant_panel.rs @@ -7,7 +7,8 @@ use crate::{ }; use ai::{ - completion::CompletionRequest, + auth::ProviderCredential, + completion::{CompletionProvider, CompletionRequest}, providers::open_ai::{ stream_completion, OpenAICompletionProvider, OpenAIRequest, RequestMessage, OPENAI_API_URL, }, @@ -100,8 +101,8 @@ pub fn init(cx: &mut AppContext) { cx.capture_action(ConversationEditor::copy); cx.add_action(ConversationEditor::split); cx.capture_action(ConversationEditor::cycle_message_role); - cx.add_action(AssistantPanel::save_api_key); - cx.add_action(AssistantPanel::reset_api_key); + cx.add_action(AssistantPanel::save_credentials); + cx.add_action(AssistantPanel::reset_credentials); cx.add_action(AssistantPanel::toggle_zoom); cx.add_action(AssistantPanel::deploy); cx.add_action(AssistantPanel::select_next_match); @@ -143,7 +144,8 @@ pub struct AssistantPanel { zoomed: bool, has_focus: bool, toolbar: ViewHandle, - api_key: Rc>>, + credential: Rc>, + completion_provider: Box, api_key_editor: Option>, has_read_credentials: bool, languages: Arc, @@ -205,6 +207,12 @@ impl AssistantPanel { }); let semantic_index = SemanticIndex::global(cx); + // Defaulting currently to GPT4, allow for this to be set via config. + let completion_provider = Box::new(OpenAICompletionProvider::new( + "gpt-4", + ProviderCredential::NoCredentials, + cx.background().clone(), + )); let mut this = Self { workspace: workspace_handle, @@ -216,7 +224,8 @@ impl AssistantPanel { zoomed: false, has_focus: false, toolbar, - api_key: Rc::new(RefCell::new(None)), + credential: Rc::new(RefCell::new(ProviderCredential::NoCredentials)), + completion_provider, api_key_editor: None, has_read_credentials: false, languages: workspace.app_state().languages.clone(), @@ -257,10 +266,7 @@ impl AssistantPanel { cx: &mut ViewContext, ) { let this = if let Some(this) = workspace.panel::(cx) { - if this - .update(cx, |assistant, cx| assistant.load_api_key(cx)) - .is_some() - { + if this.update(cx, |assistant, cx| assistant.has_credentials(cx)) { this } else { workspace.focus_panel::(cx); @@ -292,12 +298,7 @@ impl AssistantPanel { cx: &mut ViewContext, project: &ModelHandle, ) { - let api_key = if let Some(api_key) = self.api_key.borrow().clone() { - api_key - } else { - return; - }; - + let credential = self.credential.borrow().clone(); let selection = editor.read(cx).selections.newest_anchor().clone(); if selection.start.excerpt_id() != selection.end.excerpt_id() { return; @@ -329,7 +330,7 @@ impl AssistantPanel { let inline_assist_id = post_inc(&mut self.next_inline_assist_id); let provider = Arc::new(OpenAICompletionProvider::new( "gpt-4", - api_key, + credential, cx.background().clone(), )); @@ -816,7 +817,7 @@ impl AssistantPanel { fn new_conversation(&mut self, cx: &mut ViewContext) -> ViewHandle { let editor = cx.add_view(|cx| { ConversationEditor::new( - self.api_key.clone(), + self.credential.clone(), self.languages.clone(), self.fs.clone(), self.workspace.clone(), @@ -875,17 +876,20 @@ impl AssistantPanel { } } - fn save_api_key(&mut self, _: &menu::Confirm, cx: &mut ViewContext) { + fn save_credentials(&mut self, _: &menu::Confirm, cx: &mut ViewContext) { if let Some(api_key) = self .api_key_editor .as_ref() .map(|editor| editor.read(cx).text(cx)) { if !api_key.is_empty() { - cx.platform() - .write_credentials(OPENAI_API_URL, "Bearer", api_key.as_bytes()) - .log_err(); - *self.api_key.borrow_mut() = Some(api_key); + let credential = ProviderCredential::Credentials { + api_key: api_key.clone(), + }; + self.completion_provider + .save_credentials(cx, credential.clone()); + *self.credential.borrow_mut() = credential; + self.api_key_editor.take(); cx.focus_self(); cx.notify(); @@ -895,9 +899,9 @@ impl AssistantPanel { } } - fn reset_api_key(&mut self, _: &ResetKey, cx: &mut ViewContext) { - cx.platform().delete_credentials(OPENAI_API_URL).log_err(); - self.api_key.take(); + fn reset_credentials(&mut self, _: &ResetKey, cx: &mut ViewContext) { + self.completion_provider.delete_credentials(cx); + *self.credential.borrow_mut() = ProviderCredential::NoCredentials; self.api_key_editor = Some(build_api_key_editor(cx)); cx.focus_self(); cx.notify(); @@ -1156,13 +1160,19 @@ impl AssistantPanel { let fs = self.fs.clone(); let workspace = self.workspace.clone(); - let api_key = self.api_key.clone(); + let credential = self.credential.clone(); let languages = self.languages.clone(); cx.spawn(|this, mut cx| async move { let saved_conversation = fs.load(&path).await?; let saved_conversation = serde_json::from_str(&saved_conversation)?; let conversation = cx.add_model(|cx| { - Conversation::deserialize(saved_conversation, path.clone(), api_key, languages, cx) + Conversation::deserialize( + saved_conversation, + path.clone(), + credential, + languages, + cx, + ) }); this.update(&mut cx, |this, cx| { // If, by the time we've loaded the conversation, the user has already opened @@ -1186,30 +1196,39 @@ impl AssistantPanel { .position(|editor| editor.read(cx).conversation.read(cx).path.as_deref() == Some(path)) } - fn load_api_key(&mut self, cx: &mut ViewContext) -> Option { - if self.api_key.borrow().is_none() && !self.has_read_credentials { - self.has_read_credentials = true; - let api_key = if let Ok(api_key) = env::var("OPENAI_API_KEY") { - Some(api_key) - } else if let Some((_, api_key)) = cx - .platform() - .read_credentials(OPENAI_API_URL) - .log_err() - .flatten() - { - String::from_utf8(api_key).log_err() - } else { - None - }; - if let Some(api_key) = api_key { - *self.api_key.borrow_mut() = Some(api_key); - } else if self.api_key_editor.is_none() { - self.api_key_editor = Some(build_api_key_editor(cx)); - cx.notify(); + fn has_credentials(&mut self, cx: &mut ViewContext) -> bool { + let credential = self.load_credentials(cx); + match credential { + ProviderCredential::Credentials { .. } => true, + ProviderCredential::NotNeeded => true, + ProviderCredential::NoCredentials => false, + } + } + + fn load_credentials(&mut self, cx: &mut ViewContext) -> ProviderCredential { + let existing_credential = self.credential.clone(); + let existing_credential = existing_credential.borrow().clone(); + match existing_credential { + ProviderCredential::NoCredentials => { + if !self.has_read_credentials { + self.has_read_credentials = true; + let retrieved_credentials = self.completion_provider.retrieve_credentials(cx); + + match retrieved_credentials { + ProviderCredential::NoCredentials {} => { + self.api_key_editor = Some(build_api_key_editor(cx)); + cx.notify(); + } + _ => { + *self.credential.borrow_mut() = retrieved_credentials; + } + } + } } + _ => {} } - self.api_key.borrow().clone() + self.credential.borrow().clone() } } @@ -1394,7 +1413,7 @@ impl Panel for AssistantPanel { fn set_active(&mut self, active: bool, cx: &mut ViewContext) { if active { - self.load_api_key(cx); + self.load_credentials(cx); if self.editors.is_empty() { self.new_conversation(cx); @@ -1459,7 +1478,7 @@ struct Conversation { token_count: Option, max_token_count: usize, pending_token_count: Task>, - api_key: Rc>>, + credential: Rc>, pending_save: Task>, path: Option, _subscriptions: Vec, @@ -1471,7 +1490,8 @@ impl Entity for Conversation { impl Conversation { fn new( - api_key: Rc>>, + credential: Rc>, + language_registry: Arc, cx: &mut ModelContext, ) -> Self { @@ -1512,7 +1532,7 @@ impl Conversation { _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)], pending_save: Task::ready(Ok(())), path: None, - api_key, + credential, buffer, }; let message = MessageAnchor { @@ -1559,7 +1579,7 @@ impl Conversation { fn deserialize( saved_conversation: SavedConversation, path: PathBuf, - api_key: Rc>>, + credential: Rc>, language_registry: Arc, cx: &mut ModelContext, ) -> Self { @@ -1614,7 +1634,7 @@ impl Conversation { _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)], pending_save: Task::ready(Ok(())), path: Some(path), - api_key, + credential, buffer, }; this.count_remaining_tokens(cx); @@ -1736,9 +1756,13 @@ impl Conversation { } if should_assist { - let Some(api_key) = self.api_key.borrow().clone() else { - return Default::default(); - }; + let credential = self.credential.borrow().clone(); + match credential { + ProviderCredential::NoCredentials => { + return Default::default(); + } + _ => {} + } let request: Box = Box::new(OpenAIRequest { model: self.model.full_name().to_string(), @@ -1752,7 +1776,7 @@ impl Conversation { temperature: 1.0, }); - let stream = stream_completion(api_key, cx.background().clone(), request); + let stream = stream_completion(credential, cx.background().clone(), request); let assistant_message = self .insert_message_after(last_message_id, Role::Assistant, MessageStatus::Pending, cx) .unwrap(); @@ -2018,57 +2042,62 @@ impl Conversation { fn summarize(&mut self, cx: &mut ModelContext) { if self.message_anchors.len() >= 2 && self.summary.is_none() { - let api_key = self.api_key.borrow().clone(); - if let Some(api_key) = api_key { - let messages = self - .messages(cx) - .take(2) - .map(|message| message.to_open_ai_message(self.buffer.read(cx))) - .chain(Some(RequestMessage { - role: Role::User, - content: - "Summarize the conversation into a short title without punctuation" - .into(), - })); - let request: Box = Box::new(OpenAIRequest { - model: self.model.full_name().to_string(), - messages: messages.collect(), - stream: true, - stop: vec![], - temperature: 1.0, - }); + let credential = self.credential.borrow().clone(); - let stream = stream_completion(api_key, cx.background().clone(), request); - self.pending_summary = cx.spawn(|this, mut cx| { - async move { - let mut messages = stream.await?; - - while let Some(message) = messages.next().await { - let mut message = message?; - if let Some(choice) = message.choices.pop() { - let text = choice.delta.content.unwrap_or_default(); - this.update(&mut cx, |this, cx| { - this.summary - .get_or_insert(Default::default()) - .text - .push_str(&text); - cx.emit(ConversationEvent::SummaryChanged); - }); - } - } - - this.update(&mut cx, |this, cx| { - if let Some(summary) = this.summary.as_mut() { - summary.done = true; - cx.emit(ConversationEvent::SummaryChanged); - } - }); - - anyhow::Ok(()) - } - .log_err() - }); + match credential { + ProviderCredential::NoCredentials => { + return; + } + _ => {} } + + let messages = self + .messages(cx) + .take(2) + .map(|message| message.to_open_ai_message(self.buffer.read(cx))) + .chain(Some(RequestMessage { + role: Role::User, + content: "Summarize the conversation into a short title without punctuation" + .into(), + })); + let request: Box = Box::new(OpenAIRequest { + model: self.model.full_name().to_string(), + messages: messages.collect(), + stream: true, + stop: vec![], + temperature: 1.0, + }); + + let stream = stream_completion(credential, cx.background().clone(), request); + self.pending_summary = cx.spawn(|this, mut cx| { + async move { + let mut messages = stream.await?; + + while let Some(message) = messages.next().await { + let mut message = message?; + if let Some(choice) = message.choices.pop() { + let text = choice.delta.content.unwrap_or_default(); + this.update(&mut cx, |this, cx| { + this.summary + .get_or_insert(Default::default()) + .text + .push_str(&text); + cx.emit(ConversationEvent::SummaryChanged); + }); + } + } + + this.update(&mut cx, |this, cx| { + if let Some(summary) = this.summary.as_mut() { + summary.done = true; + cx.emit(ConversationEvent::SummaryChanged); + } + }); + + anyhow::Ok(()) + } + .log_err() + }); } } @@ -2229,13 +2258,13 @@ struct ConversationEditor { impl ConversationEditor { fn new( - api_key: Rc>>, + credential: Rc>, language_registry: Arc, fs: Arc, workspace: WeakViewHandle, cx: &mut ViewContext, ) -> Self { - let conversation = cx.add_model(|cx| Conversation::new(api_key, language_registry, cx)); + let conversation = cx.add_model(|cx| Conversation::new(credential, language_registry, cx)); Self::for_conversation(conversation, fs, workspace, cx) } @@ -3431,7 +3460,13 @@ mod tests { cx.set_global(SettingsStore::test(cx)); init(cx); let registry = Arc::new(LanguageRegistry::test()); - let conversation = cx.add_model(|cx| Conversation::new(Default::default(), registry, cx)); + let conversation = cx.add_model(|cx| { + Conversation::new( + Rc::new(RefCell::new(ProviderCredential::NotNeeded)), + registry, + cx, + ) + }); let buffer = conversation.read(cx).buffer.clone(); let message_1 = conversation.read(cx).message_anchors[0].clone(); @@ -3559,7 +3594,13 @@ mod tests { cx.set_global(SettingsStore::test(cx)); init(cx); let registry = Arc::new(LanguageRegistry::test()); - let conversation = cx.add_model(|cx| Conversation::new(Default::default(), registry, cx)); + let conversation = cx.add_model(|cx| { + Conversation::new( + Rc::new(RefCell::new(ProviderCredential::NotNeeded)), + registry, + cx, + ) + }); let buffer = conversation.read(cx).buffer.clone(); let message_1 = conversation.read(cx).message_anchors[0].clone(); @@ -3655,7 +3696,13 @@ mod tests { cx.set_global(SettingsStore::test(cx)); init(cx); let registry = Arc::new(LanguageRegistry::test()); - let conversation = cx.add_model(|cx| Conversation::new(Default::default(), registry, cx)); + let conversation = cx.add_model(|cx| { + Conversation::new( + Rc::new(RefCell::new(ProviderCredential::NotNeeded)), + registry, + cx, + ) + }); let buffer = conversation.read(cx).buffer.clone(); let message_1 = conversation.read(cx).message_anchors[0].clone(); @@ -3737,8 +3784,13 @@ mod tests { cx.set_global(SettingsStore::test(cx)); init(cx); let registry = Arc::new(LanguageRegistry::test()); - let conversation = - cx.add_model(|cx| Conversation::new(Default::default(), registry.clone(), cx)); + let conversation = cx.add_model(|cx| { + Conversation::new( + Rc::new(RefCell::new(ProviderCredential::NotNeeded)), + registry.clone(), + cx, + ) + }); let buffer = conversation.read(cx).buffer.clone(); let message_0 = conversation.read(cx).message_anchors[0].id; let message_1 = conversation.update(cx, |conversation, cx| { @@ -3775,7 +3827,7 @@ mod tests { Conversation::deserialize( conversation.read(cx).serialize(cx), Default::default(), - Default::default(), + Rc::new(RefCell::new(ProviderCredential::NotNeeded)), registry.clone(), cx, )