diff --git a/crates/ai/src/completion.rs b/crates/ai/src/completion.rs index 7fdc49e918..30a60fcf1d 100644 --- a/crates/ai/src/completion.rs +++ b/crates/ai/src/completion.rs @@ -13,4 +13,11 @@ pub trait CompletionProvider: CredentialProvider { &self, prompt: Box, ) -> BoxFuture<'static, Result>>>; + fn box_clone(&self) -> Box; +} + +impl Clone for Box { + fn clone(&self) -> Box { + self.box_clone() + } } diff --git a/crates/ai/src/prompts/base.rs b/crates/ai/src/prompts/base.rs index a2106c7410..75bad00154 100644 --- a/crates/ai/src/prompts/base.rs +++ b/crates/ai/src/prompts/base.rs @@ -147,7 +147,7 @@ pub(crate) mod tests { content = args.model.truncate( &content, max_token_length, - TruncationDirection::Start, + TruncationDirection::End, )?; token_count = max_token_length; } @@ -172,7 +172,7 @@ pub(crate) mod tests { content = args.model.truncate( &content, max_token_length, - TruncationDirection::Start, + TruncationDirection::End, )?; token_count = max_token_length; } diff --git a/crates/ai/src/providers/open_ai/completion.rs b/crates/ai/src/providers/open_ai/completion.rs index 02d25a7eec..94685fd233 100644 --- a/crates/ai/src/providers/open_ai/completion.rs +++ b/crates/ai/src/providers/open_ai/completion.rs @@ -193,6 +193,7 @@ pub async fn stream_completion( } } +#[derive(Clone)] pub struct OpenAICompletionProvider { model: OpenAILanguageModel, credential: Arc>, @@ -271,6 +272,10 @@ impl CompletionProvider for OpenAICompletionProvider { &self, prompt: Box, ) -> BoxFuture<'static, Result>>> { + // Currently the CompletionRequest for OpenAI, includes a 'model' parameter + // This means that the model is determined by the CompletionRequest and not the CompletionProvider, + // which is currently model based, due to the langauge model. + // At some point in the future we should rectify this. let credential = self.credential.read().clone(); let request = stream_completion(credential, self.executor.clone(), prompt); async move { @@ -287,4 +292,7 @@ impl CompletionProvider for OpenAICompletionProvider { } .boxed() } + fn box_clone(&self) -> Box { + Box::new((*self).clone()) + } } diff --git a/crates/ai/src/test.rs b/crates/ai/src/test.rs index bc9a6a3e43..d4165f3cca 100644 --- a/crates/ai/src/test.rs +++ b/crates/ai/src/test.rs @@ -33,7 +33,10 @@ impl LanguageModel for FakeLanguageModel { length: usize, direction: TruncationDirection, ) -> anyhow::Result { + println!("TRYING TO TRUNCATE: {:?}", length.clone()); + if length > self.count_tokens(content)? { + println!("NOT TRUNCATING"); return anyhow::Ok(content.to_string()); } @@ -133,6 +136,14 @@ pub struct FakeCompletionProvider { last_completion_tx: Mutex>>, } +impl Clone for FakeCompletionProvider { + fn clone(&self) -> Self { + Self { + last_completion_tx: Mutex::new(None), + } + } +} + impl FakeCompletionProvider { pub fn new() -> Self { Self { @@ -174,4 +185,7 @@ impl CompletionProvider for FakeCompletionProvider { *self.last_completion_tx.lock() = Some(tx); async move { Ok(rx.map(|rx| Ok(rx)).boxed()) }.boxed() } + fn box_clone(&self) -> Box { + Box::new((*self).clone()) + } } diff --git a/crates/assistant/src/assistant_panel.rs b/crates/assistant/src/assistant_panel.rs index c10ad2c362..d0c7e7e883 100644 --- a/crates/assistant/src/assistant_panel.rs +++ b/crates/assistant/src/assistant_panel.rs @@ -9,9 +9,7 @@ use crate::{ use ai::{ auth::ProviderCredential, completion::{CompletionProvider, CompletionRequest}, - providers::open_ai::{ - stream_completion, OpenAICompletionProvider, OpenAIRequest, RequestMessage, - }, + providers::open_ai::{OpenAICompletionProvider, OpenAIRequest, RequestMessage}, }; use ai::prompts::repository_context::PromptCodeSnippet; @@ -47,7 +45,7 @@ use search::BufferSearchBar; use semantic_index::{SemanticIndex, SemanticIndexStatus}; use settings::SettingsStore; use std::{ - cell::{Cell, RefCell}, + cell::Cell, cmp, fmt::Write, iter, @@ -144,10 +142,8 @@ pub struct AssistantPanel { zoomed: bool, has_focus: bool, toolbar: ViewHandle, - credential: Rc>, completion_provider: Box, api_key_editor: Option>, - has_read_credentials: bool, languages: Arc, fs: Arc, subscriptions: Vec, @@ -223,10 +219,8 @@ impl AssistantPanel { zoomed: false, has_focus: false, toolbar, - credential: Rc::new(RefCell::new(ProviderCredential::NoCredentials)), completion_provider, api_key_editor: None, - has_read_credentials: false, languages: workspace.app_state().languages.clone(), fs: workspace.app_state().fs.clone(), width: None, @@ -265,7 +259,7 @@ impl AssistantPanel { cx: &mut ViewContext, ) { let this = if let Some(this) = workspace.panel::(cx) { - if this.update(cx, |assistant, cx| assistant.has_credentials(cx)) { + if this.update(cx, |assistant, _| assistant.has_credentials()) { this } else { workspace.focus_panel::(cx); @@ -331,6 +325,9 @@ impl AssistantPanel { cx.background().clone(), )); + // Retrieve Credentials Authenticates the Provider + // provider.retrieve_credentials(cx); + let codegen = cx.add_model(|cx| { Codegen::new(editor.read(cx).buffer().clone(), codegen_kind, provider, cx) }); @@ -814,7 +811,7 @@ impl AssistantPanel { fn new_conversation(&mut self, cx: &mut ViewContext) -> ViewHandle { let editor = cx.add_view(|cx| { ConversationEditor::new( - self.credential.clone(), + self.completion_provider.clone(), self.languages.clone(), self.fs.clone(), self.workspace.clone(), @@ -883,9 +880,8 @@ impl AssistantPanel { let credential = ProviderCredential::Credentials { api_key: api_key.clone(), }; - self.completion_provider - .save_credentials(cx, credential.clone()); - *self.credential.borrow_mut() = credential; + + self.completion_provider.save_credentials(cx, credential); self.api_key_editor.take(); cx.focus_self(); @@ -898,7 +894,6 @@ impl AssistantPanel { fn reset_credentials(&mut self, _: &ResetKey, cx: &mut ViewContext) { self.completion_provider.delete_credentials(cx); - *self.credential.borrow_mut() = ProviderCredential::NoCredentials; self.api_key_editor = Some(build_api_key_editor(cx)); cx.focus_self(); cx.notify(); @@ -1157,19 +1152,12 @@ impl AssistantPanel { let fs = self.fs.clone(); let workspace = self.workspace.clone(); - let credential = self.credential.clone(); let languages = self.languages.clone(); cx.spawn(|this, mut cx| async move { let saved_conversation = fs.load(&path).await?; let saved_conversation = serde_json::from_str(&saved_conversation)?; let conversation = cx.add_model(|cx| { - Conversation::deserialize( - saved_conversation, - path.clone(), - credential, - languages, - cx, - ) + Conversation::deserialize(saved_conversation, path.clone(), languages, cx) }); this.update(&mut cx, |this, cx| { // If, by the time we've loaded the conversation, the user has already opened @@ -1193,39 +1181,12 @@ impl AssistantPanel { .position(|editor| editor.read(cx).conversation.read(cx).path.as_deref() == Some(path)) } - fn has_credentials(&mut self, cx: &mut ViewContext) -> bool { - let credential = self.load_credentials(cx); - match credential { - ProviderCredential::Credentials { .. } => true, - ProviderCredential::NotNeeded => true, - ProviderCredential::NoCredentials => false, - } + fn has_credentials(&mut self) -> bool { + self.completion_provider.has_credentials() } - fn load_credentials(&mut self, cx: &mut ViewContext) -> ProviderCredential { - let existing_credential = self.credential.clone(); - let existing_credential = existing_credential.borrow().clone(); - match existing_credential { - ProviderCredential::NoCredentials => { - if !self.has_read_credentials { - self.has_read_credentials = true; - let retrieved_credentials = self.completion_provider.retrieve_credentials(cx); - - match retrieved_credentials { - ProviderCredential::NoCredentials {} => { - self.api_key_editor = Some(build_api_key_editor(cx)); - cx.notify(); - } - _ => { - *self.credential.borrow_mut() = retrieved_credentials; - } - } - } - } - _ => {} - } - - self.credential.borrow().clone() + fn load_credentials(&mut self, cx: &mut ViewContext) { + self.completion_provider.retrieve_credentials(cx); } } @@ -1475,10 +1436,10 @@ struct Conversation { token_count: Option, max_token_count: usize, pending_token_count: Task>, - credential: Rc>, pending_save: Task>, path: Option, _subscriptions: Vec, + completion_provider: Box, } impl Entity for Conversation { @@ -1487,10 +1448,9 @@ impl Entity for Conversation { impl Conversation { fn new( - credential: Rc>, - language_registry: Arc, cx: &mut ModelContext, + completion_provider: Box, ) -> Self { let markdown = language_registry.language_for_name("Markdown"); let buffer = cx.add_model(|cx| { @@ -1529,8 +1489,8 @@ impl Conversation { _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)], pending_save: Task::ready(Ok(())), path: None, - credential, buffer, + completion_provider, }; let message = MessageAnchor { id: MessageId(post_inc(&mut this.next_message_id.0)), @@ -1576,7 +1536,6 @@ impl Conversation { fn deserialize( saved_conversation: SavedConversation, path: PathBuf, - credential: Rc>, language_registry: Arc, cx: &mut ModelContext, ) -> Self { @@ -1585,6 +1544,10 @@ impl Conversation { None => Some(Uuid::new_v4().to_string()), }; let model = saved_conversation.model; + let completion_provider: Box = Box::new( + OpenAICompletionProvider::new(model.full_name(), cx.background().clone()), + ); + completion_provider.retrieve_credentials(cx); let markdown = language_registry.language_for_name("Markdown"); let mut message_anchors = Vec::new(); let mut next_message_id = MessageId(0); @@ -1631,8 +1594,8 @@ impl Conversation { _subscriptions: vec![cx.subscribe(&buffer, Self::handle_buffer_event)], pending_save: Task::ready(Ok(())), path: Some(path), - credential, buffer, + completion_provider, }; this.count_remaining_tokens(cx); this @@ -1753,12 +1716,8 @@ impl Conversation { } if should_assist { - let credential = self.credential.borrow().clone(); - match credential { - ProviderCredential::NoCredentials => { - return Default::default(); - } - _ => {} + if !self.completion_provider.has_credentials() { + return Default::default(); } let request: Box = Box::new(OpenAIRequest { @@ -1773,7 +1732,7 @@ impl Conversation { temperature: 1.0, }); - let stream = stream_completion(credential, cx.background().clone(), request); + let stream = self.completion_provider.complete(request); let assistant_message = self .insert_message_after(last_message_id, Role::Assistant, MessageStatus::Pending, cx) .unwrap(); @@ -1791,33 +1750,28 @@ impl Conversation { let mut messages = stream.await?; while let Some(message) = messages.next().await { - let mut message = message?; - if let Some(choice) = message.choices.pop() { - this.upgrade(&cx) - .ok_or_else(|| anyhow!("conversation was dropped"))? - .update(&mut cx, |this, cx| { - let text: Arc = choice.delta.content?.into(); - let message_ix = - this.message_anchors.iter().position(|message| { - message.id == assistant_message_id - })?; - this.buffer.update(cx, |buffer, cx| { - let offset = this.message_anchors[message_ix + 1..] - .iter() - .find(|message| message.start.is_valid(buffer)) - .map_or(buffer.len(), |message| { - message - .start - .to_offset(buffer) - .saturating_sub(1) - }); - buffer.edit([(offset..offset, text)], None, cx); - }); - cx.emit(ConversationEvent::StreamedCompletion); + let text = message?; - Some(()) + this.upgrade(&cx) + .ok_or_else(|| anyhow!("conversation was dropped"))? + .update(&mut cx, |this, cx| { + let message_ix = this + .message_anchors + .iter() + .position(|message| message.id == assistant_message_id)?; + this.buffer.update(cx, |buffer, cx| { + let offset = this.message_anchors[message_ix + 1..] + .iter() + .find(|message| message.start.is_valid(buffer)) + .map_or(buffer.len(), |message| { + message.start.to_offset(buffer).saturating_sub(1) + }); + buffer.edit([(offset..offset, text)], None, cx); }); - } + cx.emit(ConversationEvent::StreamedCompletion); + + Some(()) + }); smol::future::yield_now().await; } @@ -2039,13 +1993,8 @@ impl Conversation { fn summarize(&mut self, cx: &mut ModelContext) { if self.message_anchors.len() >= 2 && self.summary.is_none() { - let credential = self.credential.borrow().clone(); - - match credential { - ProviderCredential::NoCredentials => { - return; - } - _ => {} + if !self.completion_provider.has_credentials() { + return; } let messages = self @@ -2065,23 +2014,20 @@ impl Conversation { temperature: 1.0, }); - let stream = stream_completion(credential, cx.background().clone(), request); + let stream = self.completion_provider.complete(request); self.pending_summary = cx.spawn(|this, mut cx| { async move { let mut messages = stream.await?; while let Some(message) = messages.next().await { - let mut message = message?; - if let Some(choice) = message.choices.pop() { - let text = choice.delta.content.unwrap_or_default(); - this.update(&mut cx, |this, cx| { - this.summary - .get_or_insert(Default::default()) - .text - .push_str(&text); - cx.emit(ConversationEvent::SummaryChanged); - }); - } + let text = message?; + this.update(&mut cx, |this, cx| { + this.summary + .get_or_insert(Default::default()) + .text + .push_str(&text); + cx.emit(ConversationEvent::SummaryChanged); + }); } this.update(&mut cx, |this, cx| { @@ -2255,13 +2201,14 @@ struct ConversationEditor { impl ConversationEditor { fn new( - credential: Rc>, + completion_provider: Box, language_registry: Arc, fs: Arc, workspace: WeakViewHandle, cx: &mut ViewContext, ) -> Self { - let conversation = cx.add_model(|cx| Conversation::new(credential, language_registry, cx)); + let conversation = + cx.add_model(|cx| Conversation::new(language_registry, cx, completion_provider)); Self::for_conversation(conversation, fs, workspace, cx) } @@ -3450,6 +3397,7 @@ fn merge_ranges(ranges: &mut Vec>, buffer: &MultiBufferSnapshot) { mod tests { use super::*; use crate::MessageId; + use ai::test::FakeCompletionProvider; use gpui::AppContext; #[gpui::test] @@ -3457,13 +3405,9 @@ mod tests { cx.set_global(SettingsStore::test(cx)); init(cx); let registry = Arc::new(LanguageRegistry::test()); - let conversation = cx.add_model(|cx| { - Conversation::new( - Rc::new(RefCell::new(ProviderCredential::NotNeeded)), - registry, - cx, - ) - }); + + let completion_provider = Box::new(FakeCompletionProvider::new()); + let conversation = cx.add_model(|cx| Conversation::new(registry, cx, completion_provider)); let buffer = conversation.read(cx).buffer.clone(); let message_1 = conversation.read(cx).message_anchors[0].clone(); @@ -3591,13 +3535,9 @@ mod tests { cx.set_global(SettingsStore::test(cx)); init(cx); let registry = Arc::new(LanguageRegistry::test()); - let conversation = cx.add_model(|cx| { - Conversation::new( - Rc::new(RefCell::new(ProviderCredential::NotNeeded)), - registry, - cx, - ) - }); + let completion_provider = Box::new(FakeCompletionProvider::new()); + + let conversation = cx.add_model(|cx| Conversation::new(registry, cx, completion_provider)); let buffer = conversation.read(cx).buffer.clone(); let message_1 = conversation.read(cx).message_anchors[0].clone(); @@ -3693,13 +3633,8 @@ mod tests { cx.set_global(SettingsStore::test(cx)); init(cx); let registry = Arc::new(LanguageRegistry::test()); - let conversation = cx.add_model(|cx| { - Conversation::new( - Rc::new(RefCell::new(ProviderCredential::NotNeeded)), - registry, - cx, - ) - }); + let completion_provider = Box::new(FakeCompletionProvider::new()); + let conversation = cx.add_model(|cx| Conversation::new(registry, cx, completion_provider)); let buffer = conversation.read(cx).buffer.clone(); let message_1 = conversation.read(cx).message_anchors[0].clone(); @@ -3781,13 +3716,9 @@ mod tests { cx.set_global(SettingsStore::test(cx)); init(cx); let registry = Arc::new(LanguageRegistry::test()); - let conversation = cx.add_model(|cx| { - Conversation::new( - Rc::new(RefCell::new(ProviderCredential::NotNeeded)), - registry.clone(), - cx, - ) - }); + let completion_provider = Box::new(FakeCompletionProvider::new()); + let conversation = + cx.add_model(|cx| Conversation::new(registry.clone(), cx, completion_provider)); let buffer = conversation.read(cx).buffer.clone(); let message_0 = conversation.read(cx).message_anchors[0].id; let message_1 = conversation.update(cx, |conversation, cx| { @@ -3824,7 +3755,6 @@ mod tests { Conversation::deserialize( conversation.read(cx).serialize(cx), Default::default(), - Rc::new(RefCell::new(ProviderCredential::NotNeeded)), registry.clone(), cx, )