diff --git a/crates/assistant/src/assistant.rs b/crates/assistant/src/assistant.rs
index 9d2e8c0142..9b7d4e3f73 100644
--- a/crates/assistant/src/assistant.rs
+++ b/crates/assistant/src/assistant.rs
@@ -163,7 +163,7 @@ impl LanguageModelRequestMessage {
     }
 }
 
-#[derive(Debug, Default, Serialize)]
+#[derive(Debug, Default, Serialize, Deserialize)]
 pub struct LanguageModelRequest {
     pub model: LanguageModel,
     pub messages: Vec<LanguageModelRequestMessage>,
diff --git a/crates/assistant/src/assistant_panel.rs b/crates/assistant/src/assistant_panel.rs
index f097338e01..5ba8447087 100644
--- a/crates/assistant/src/assistant_panel.rs
+++ b/crates/assistant/src/assistant_panel.rs
@@ -1409,7 +1409,7 @@ impl Context {
         }
 
         let request = self.to_completion_request(cx);
-        let stream = CompletionProvider::global(cx).complete(request);
+        let response = CompletionProvider::global(cx).complete(request, cx);
         let assistant_message = self
             .insert_message_after(last_message_id, Role::Assistant, MessageStatus::Pending, cx)
             .unwrap();
@@ -1422,11 +1422,12 @@ impl Context {
 
         let task = cx.spawn({
             |this, mut cx| async move {
+                let response = response.await;
                 let assistant_message_id = assistant_message.id;
                 let mut response_latency = None;
                 let stream_completion = async {
                     let request_start = Instant::now();
-                    let mut messages = stream.await?;
+                    let mut messages = response.inner.await?;
 
                     while let Some(message) = messages.next().await {
                         if response_latency.is_none() {
@@ -1718,10 +1719,11 @@ impl Context {
             temperature: 1.0,
         };
 
-        let stream = CompletionProvider::global(cx).complete(request);
+        let response = CompletionProvider::global(cx).complete(request, cx);
         self.pending_summary = cx.spawn(|this, mut cx| {
             async move {
-                let mut messages = stream.await?;
+                let response = response.await;
+                let mut messages = response.inner.await?;
 
                 while let Some(message) = messages.next().await {
                     let text = message?;
@@ -3642,7 +3644,7 @@ mod tests {
     #[gpui::test]
     fn test_inserting_and_removing_messages(cx: &mut AppContext) {
         let settings_store = SettingsStore::test(cx);
-        cx.set_global(CompletionProvider::Fake(FakeCompletionProvider::default()));
+        FakeCompletionProvider::setup_test(cx);
         cx.set_global(settings_store);
         init(cx);
         let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
@@ -3774,7 +3776,7 @@ mod tests {
     fn test_message_splitting(cx: &mut AppContext) {
         let settings_store = SettingsStore::test(cx);
         cx.set_global(settings_store);
-        cx.set_global(CompletionProvider::Fake(FakeCompletionProvider::default()));
+        FakeCompletionProvider::setup_test(cx);
         init(cx);
 
         let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
@@ -3867,7 +3869,7 @@ mod tests {
     #[gpui::test]
     fn test_messages_for_offsets(cx: &mut AppContext) {
         let settings_store = SettingsStore::test(cx);
-        cx.set_global(CompletionProvider::Fake(FakeCompletionProvider::default()));
+        FakeCompletionProvider::setup_test(cx);
         cx.set_global(settings_store);
         init(cx);
         let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
@@ -3952,7 +3954,8 @@ mod tests {
     async fn test_slash_commands(cx: &mut TestAppContext) {
         let settings_store = cx.update(SettingsStore::test);
         cx.set_global(settings_store);
-        cx.set_global(CompletionProvider::Fake(FakeCompletionProvider::default()));
+        cx.update(|cx| FakeCompletionProvider::setup_test(cx));
+        cx.update(Project::init_settings);
         cx.update(init);
         let fs = FakeFs::new(cx.background_executor.clone());
@@ -4147,7 +4150,7 @@ mod tests {
     async fn test_serialization(cx: &mut TestAppContext) {
         let settings_store = cx.update(SettingsStore::test);
         cx.set_global(settings_store);
-        cx.set_global(CompletionProvider::Fake(FakeCompletionProvider::default()));
+        cx.update(FakeCompletionProvider::setup_test);
         cx.update(init);
         let registry = Arc::new(LanguageRegistry::test(cx.executor()));
         let context =
diff --git a/crates/assistant/src/assistant_settings.rs b/crates/assistant/src/assistant_settings.rs
index 2d5218fac2..8ff078512a 100644
--- a/crates/assistant/src/assistant_settings.rs
+++ b/crates/assistant/src/assistant_settings.rs
@@ -1,5 +1,6 @@
 use std::fmt;
 
+use crate::{preprocess_anthropic_request, LanguageModel, LanguageModelRequest};
 pub use anthropic::Model as AnthropicModel;
 use gpui::Pixels;
 pub use ollama::Model as OllamaModel;
@@ -15,8 +16,6 @@ use serde::{
 use settings::{Settings, SettingsSources};
 use strum::{EnumIter, IntoEnumIterator};
 
-use crate::{preprocess_anthropic_request, LanguageModel, LanguageModelRequest};
-
 #[derive(Clone, Debug, Default, PartialEq, EnumIter)]
 pub enum CloudModel {
     Gpt3Point5Turbo,
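A note on the call-site changes above before the provider refactor itself: `complete` now takes `cx` and returns a `Task<CompletionResponse>` instead of the stream future directly, so every caller awaits twice — once for a rate-limiter slot, once for the stream. A minimal sketch of the consumption pattern used throughout this diff:

    // Inside an async body spawned from the app context:
    // let response = CompletionProvider::global(cx).complete(request, cx); // Task<CompletionResponse>
    let response = response.await;          // waits until a semaphore permit is acquired
    let mut chunks = response.inner.await?; // connects and yields the text stream
    while let Some(chunk) = chunks.next().await {
        let text = chunk?; // each chunk is a Result<String> of streamed completion text
    }
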
diff --git a/crates/assistant/src/completion_provider.rs b/crates/assistant/src/completion_provider.rs
index 8badd32d84..36a5bc883e 100644
--- a/crates/assistant/src/completion_provider.rs
+++ b/crates/assistant/src/completion_provider.rs
@@ -11,6 +11,8 @@ pub use cloud::*;
 pub use fake::*;
 pub use ollama::*;
 pub use open_ai::*;
+use parking_lot::RwLock;
+use smol::lock::{Semaphore, SemaphoreGuardArc};
 
 use crate::{
     assistant_settings::{AssistantProvider, AssistantSettings},
@@ -21,8 +23,8 @@ use client::Client;
 use futures::{future::BoxFuture, stream::BoxStream};
 use gpui::{AnyView, AppContext, BorrowAppContext, Task, WindowContext};
 use settings::{Settings, SettingsStore};
-use std::sync::Arc;
 use std::time::Duration;
+use std::{any::Any, sync::Arc};
 
 /// Choose which model to use for openai provider.
 /// If the model is not available, try to use the first available model, or fallback to the original model.
@@ -39,176 +41,117 @@ fn choose_openai_model(
 }
 
 pub fn init(client: Arc<Client>, cx: &mut AppContext) {
-    let mut settings_version = 0;
-    let provider = match &AssistantSettings::get_global(cx).provider {
-        AssistantProvider::ZedDotDev { model } => CompletionProvider::Cloud(
-            CloudCompletionProvider::new(model.clone(), client.clone(), settings_version, cx),
-        ),
-        AssistantProvider::OpenAi {
-            model,
-            api_url,
-            low_speed_timeout_in_seconds,
-            available_models,
-        } => CompletionProvider::OpenAi(OpenAiCompletionProvider::new(
-            choose_openai_model(model, available_models),
-            api_url.clone(),
-            client.http_client(),
-            low_speed_timeout_in_seconds.map(Duration::from_secs),
-            settings_version,
-        )),
-        AssistantProvider::Anthropic {
-            model,
-            api_url,
-            low_speed_timeout_in_seconds,
-        } => CompletionProvider::Anthropic(AnthropicCompletionProvider::new(
-            model.clone(),
-            api_url.clone(),
-            client.http_client(),
-            low_speed_timeout_in_seconds.map(Duration::from_secs),
-            settings_version,
-        )),
-        AssistantProvider::Ollama {
-            model,
-            api_url,
-            low_speed_timeout_in_seconds,
-        } => CompletionProvider::Ollama(OllamaCompletionProvider::new(
-            model.clone(),
-            api_url.clone(),
-            client.http_client(),
-            low_speed_timeout_in_seconds.map(Duration::from_secs),
-            settings_version,
-            cx,
-        )),
-    };
-    cx.set_global(provider);
+    let provider = create_provider_from_settings(client.clone(), 0, cx);
+    cx.set_global(CompletionProvider::new(provider, Some(client)));
 
+    let mut settings_version = 0;
     cx.observe_global::<SettingsStore>(move |cx| {
         settings_version += 1;
         cx.update_global::<CompletionProvider, _>(|provider, cx| {
-            match (&mut *provider, &AssistantSettings::get_global(cx).provider) {
-                (
-                    CompletionProvider::OpenAi(provider),
-                    AssistantProvider::OpenAi {
-                        model,
-                        api_url,
-                        low_speed_timeout_in_seconds,
-                        available_models,
-                    },
-                ) => {
-                    provider.update(
-                        choose_openai_model(model, available_models),
-                        api_url.clone(),
-                        low_speed_timeout_in_seconds.map(Duration::from_secs),
-                        settings_version,
-                    );
-                }
-                (
-                    CompletionProvider::Anthropic(provider),
-                    AssistantProvider::Anthropic {
-                        model,
-                        api_url,
-                        low_speed_timeout_in_seconds,
-                    },
-                ) => {
-                    provider.update(
-                        model.clone(),
-                        api_url.clone(),
-                        low_speed_timeout_in_seconds.map(Duration::from_secs),
-                        settings_version,
-                    );
-                }
-
-                (
-                    CompletionProvider::Ollama(provider),
-                    AssistantProvider::Ollama {
-                        model,
-                        api_url,
-                        low_speed_timeout_in_seconds,
-                    },
-                ) => {
-                    provider.update(
-                        model.clone(),
-                        api_url.clone(),
-                        low_speed_timeout_in_seconds.map(Duration::from_secs),
-                        settings_version,
-                        cx,
-                    );
-                }
-
-                (CompletionProvider::Cloud(provider), AssistantProvider::ZedDotDev { model }) => {
-                    provider.update(model.clone(), settings_version);
-                }
-                (_, AssistantProvider::ZedDotDev { model }) => {
-                    *provider = CompletionProvider::Cloud(CloudCompletionProvider::new(
-                        model.clone(),
-                        client.clone(),
-                        settings_version,
-                        cx,
-                    ));
-                }
-                (
-                    _,
-                    AssistantProvider::OpenAi {
-                        model,
-                        api_url,
-                        low_speed_timeout_in_seconds,
-                        available_models,
-                    },
-                ) => {
-                    *provider = CompletionProvider::OpenAi(OpenAiCompletionProvider::new(
-                        choose_openai_model(model, available_models),
-                        api_url.clone(),
-                        client.http_client(),
-                        low_speed_timeout_in_seconds.map(Duration::from_secs),
-                        settings_version,
-                    ));
-                }
-                (
-                    _,
-                    AssistantProvider::Anthropic {
-                        model,
-                        api_url,
-                        low_speed_timeout_in_seconds,
-                    },
-                ) => {
-                    *provider = CompletionProvider::Anthropic(AnthropicCompletionProvider::new(
-                        model.clone(),
-                        api_url.clone(),
-                        client.http_client(),
-                        low_speed_timeout_in_seconds.map(Duration::from_secs),
-                        settings_version,
-                    ));
-                }
-                (
-                    _,
-                    AssistantProvider::Ollama {
-                        model,
-                        api_url,
-                        low_speed_timeout_in_seconds,
-                    },
-                ) => {
-                    *provider = CompletionProvider::Ollama(OllamaCompletionProvider::new(
-                        model.clone(),
-                        api_url.clone(),
-                        client.http_client(),
-                        low_speed_timeout_in_seconds.map(Duration::from_secs),
-                        settings_version,
-                        cx,
-                    ));
-                }
-            }
+            provider.update_settings(settings_version, cx);
         })
     })
     .detach();
 }
 
-pub enum CompletionProvider {
-    OpenAi(OpenAiCompletionProvider),
-    Anthropic(AnthropicCompletionProvider),
-    Cloud(CloudCompletionProvider),
-    #[cfg(test)]
-    Fake(FakeCompletionProvider),
-    Ollama(OllamaCompletionProvider),
+pub struct CompletionResponse {
+    pub inner: BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>,
+    _lock: SemaphoreGuardArc,
+}
+
+pub trait LanguageModelCompletionProvider: Send + Sync {
+    fn available_models(&self, cx: &AppContext) -> Vec<LanguageModel>;
+    fn settings_version(&self) -> usize;
+    fn is_authenticated(&self) -> bool;
+    fn authenticate(&self, cx: &AppContext) -> Task<Result<()>>;
+    fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView;
+    fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>>;
+    fn model(&self) -> LanguageModel;
+    fn count_tokens(
+        &self,
+        request: LanguageModelRequest,
+        cx: &AppContext,
+    ) -> BoxFuture<'static, Result<usize>>;
+    fn complete(
+        &self,
+        request: LanguageModelRequest,
+    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
+
+    fn as_any_mut(&mut self) -> &mut dyn Any;
+}
+
+const MAX_CONCURRENT_COMPLETION_REQUESTS: usize = 4;
+
+pub struct CompletionProvider {
+    provider: Arc<RwLock<dyn LanguageModelCompletionProvider>>,
+    client: Option<Arc<Client>>,
+    request_limiter: Arc<Semaphore>,
+}
+
+impl CompletionProvider {
+    pub fn new(
+        provider: Arc<RwLock<dyn LanguageModelCompletionProvider>>,
+        client: Option<Arc<Client>>,
+    ) -> Self {
+        Self {
+            provider,
+            client,
+            request_limiter: Arc::new(Semaphore::new(MAX_CONCURRENT_COMPLETION_REQUESTS)),
+        }
+    }
+
+    pub fn available_models(&self, cx: &AppContext) -> Vec<LanguageModel> {
+        self.provider.read().available_models(cx)
+    }
+
+    pub fn settings_version(&self) -> usize {
+        self.provider.read().settings_version()
+    }
+
+    pub fn is_authenticated(&self) -> bool {
+        self.provider.read().is_authenticated()
+    }
+
+    pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
+        self.provider.read().authenticate(cx)
+    }
+
+    pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
+        self.provider.read().authentication_prompt(cx)
+    }
+
+    pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
+        self.provider.read().reset_credentials(cx)
+    }
+
+    pub fn model(&self) -> LanguageModel {
+        self.provider.read().model()
+    }
+
+    pub fn count_tokens(
+        &self,
+        request: LanguageModelRequest,
+        cx: &AppContext,
+    ) -> BoxFuture<'static, Result<usize>> {
+        self.provider.read().count_tokens(request, cx)
+    }
+
+    pub fn complete(
+        &self,
+        request: LanguageModelRequest,
+        cx: &AppContext,
+    ) -> Task<CompletionResponse> {
+        let rate_limiter = self.request_limiter.clone();
+        let provider = self.provider.clone();
+        cx.background_executor().spawn(async move {
+            let lock = rate_limiter.acquire_arc().await;
+            let response = provider.read().complete(request);
+            CompletionResponse {
+                inner: response,
+                _lock: lock,
+            }
+        })
+    }
 }
 
 impl gpui::Global for CompletionProvider {}
 
@@ -218,121 +161,213 @@ impl CompletionProvider {
         cx.global::<Self>()
     }
 
-    pub fn available_models(&self, cx: &AppContext) -> Vec<LanguageModel> {
-        match self {
-            CompletionProvider::OpenAi(provider) => provider
-                .available_models(cx)
-                .map(LanguageModel::OpenAi)
-                .collect(),
-            CompletionProvider::Anthropic(provider) => provider
-                .available_models()
-                .map(LanguageModel::Anthropic)
-                .collect(),
-            CompletionProvider::Cloud(provider) => provider
-                .available_models()
-                .map(LanguageModel::Cloud)
-                .collect(),
-            CompletionProvider::Ollama(provider) => provider
-                .available_models()
-                .map(|model| LanguageModel::Ollama(model.clone()))
-                .collect(),
-            #[cfg(test)]
-            CompletionProvider::Fake(_) => unimplemented!(),
+    pub fn update_current_as<R, T: LanguageModelCompletionProvider + 'static>(
+        &mut self,
+        update: impl FnOnce(&mut T) -> R,
+    ) -> Option<R> {
+        let mut provider = self.provider.write();
+        if let Some(provider) = provider.as_any_mut().downcast_mut::<T>() {
+            Some(update(provider))
+        } else {
+            None
         }
     }
 
-    pub fn settings_version(&self) -> usize {
-        match self {
-            CompletionProvider::OpenAi(provider) => provider.settings_version(),
-            CompletionProvider::Anthropic(provider) => provider.settings_version(),
-            CompletionProvider::Cloud(provider) => provider.settings_version(),
-            CompletionProvider::Ollama(provider) => provider.settings_version(),
-            #[cfg(test)]
-            CompletionProvider::Fake(_) => unimplemented!(),
-        }
-    }
+    pub fn update_settings(&mut self, version: usize, cx: &mut AppContext) {
+        let updated = match &AssistantSettings::get_global(cx).provider {
+            AssistantProvider::ZedDotDev { model } => self
+                .update_current_as::<_, CloudCompletionProvider>(|provider| {
+                    provider.update(model.clone(), version);
+                }),
+            AssistantProvider::OpenAi {
+                model,
+                api_url,
+                low_speed_timeout_in_seconds,
+                available_models,
+            } => self.update_current_as::<_, OpenAiCompletionProvider>(|provider| {
+                provider.update(
+                    choose_openai_model(&model, &available_models),
+                    api_url.clone(),
+                    low_speed_timeout_in_seconds.map(Duration::from_secs),
+                    version,
+                );
+            }),
+            AssistantProvider::Anthropic {
+                model,
+                api_url,
+                low_speed_timeout_in_seconds,
+            } => self.update_current_as::<_, AnthropicCompletionProvider>(|provider| {
+                provider.update(
+                    model.clone(),
+                    api_url.clone(),
+                    low_speed_timeout_in_seconds.map(Duration::from_secs),
+                    version,
+                );
+            }),
+            AssistantProvider::Ollama {
+                model,
+                api_url,
+                low_speed_timeout_in_seconds,
+            } => self.update_current_as::<_, OllamaCompletionProvider>(|provider| {
+                provider.update(
+                    model.clone(),
+                    api_url.clone(),
+                    low_speed_timeout_in_seconds.map(Duration::from_secs),
+                    version,
+                    cx,
+                );
+            }),
+        };
 
-    pub fn is_authenticated(&self) -> bool {
-        match self {
-            CompletionProvider::OpenAi(provider) => provider.is_authenticated(),
-            CompletionProvider::Anthropic(provider) => provider.is_authenticated(),
-            CompletionProvider::Cloud(provider) => provider.is_authenticated(),
-            CompletionProvider::Ollama(provider) => provider.is_authenticated(),
-            #[cfg(test)]
-            CompletionProvider::Fake(_) => true,
-        }
-    }
-
-    pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
-        match self {
-            CompletionProvider::OpenAi(provider) => provider.authenticate(cx),
-            CompletionProvider::Anthropic(provider) => provider.authenticate(cx),
-            CompletionProvider::Cloud(provider) => provider.authenticate(cx),
-            CompletionProvider::Ollama(provider) => provider.authenticate(cx),
-            #[cfg(test)]
-            CompletionProvider::Fake(_) => Task::ready(Ok(())),
-        }
-    }
-
-    pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
-        match self {
-            CompletionProvider::OpenAi(provider) => provider.authentication_prompt(cx),
-            CompletionProvider::Anthropic(provider) => provider.authentication_prompt(cx),
-            CompletionProvider::Cloud(provider) => provider.authentication_prompt(cx),
-            CompletionProvider::Ollama(provider) => provider.authentication_prompt(cx),
-            #[cfg(test)]
-            CompletionProvider::Fake(_) => unimplemented!(),
-        }
-    }
-
-    pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
-        match self {
-            CompletionProvider::OpenAi(provider) => provider.reset_credentials(cx),
-            CompletionProvider::Anthropic(provider) => provider.reset_credentials(cx),
-            CompletionProvider::Cloud(_) => Task::ready(Ok(())),
-            CompletionProvider::Ollama(provider) => provider.reset_credentials(cx),
-            #[cfg(test)]
-            CompletionProvider::Fake(_) => Task::ready(Ok(())),
-        }
-    }
-
-    pub fn model(&self) -> LanguageModel {
-        match self {
-            CompletionProvider::OpenAi(provider) => LanguageModel::OpenAi(provider.model()),
-            CompletionProvider::Anthropic(provider) => LanguageModel::Anthropic(provider.model()),
-            CompletionProvider::Cloud(provider) => LanguageModel::Cloud(provider.model()),
-            CompletionProvider::Ollama(provider) => LanguageModel::Ollama(provider.model()),
-            #[cfg(test)]
-            CompletionProvider::Fake(_) => LanguageModel::default(),
-        }
-    }
-
-    pub fn count_tokens(
-        &self,
-        request: LanguageModelRequest,
-        cx: &AppContext,
-    ) -> BoxFuture<'static, Result<usize>> {
-        match self {
-            CompletionProvider::OpenAi(provider) => provider.count_tokens(request, cx),
-            CompletionProvider::Anthropic(provider) => provider.count_tokens(request, cx),
-            CompletionProvider::Cloud(provider) => provider.count_tokens(request, cx),
-            CompletionProvider::Ollama(provider) => provider.count_tokens(request, cx),
-            #[cfg(test)]
-            CompletionProvider::Fake(_) => futures::FutureExt::boxed(futures::future::ready(Ok(0))),
-        }
-    }
-
-    pub fn complete(
-        &self,
-        request: LanguageModelRequest,
-    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
-        match self {
-            CompletionProvider::OpenAi(provider) => provider.complete(request),
-            CompletionProvider::Anthropic(provider) => provider.complete(request),
-            CompletionProvider::Cloud(provider) => provider.complete(request),
-            CompletionProvider::Ollama(provider) => provider.complete(request),
-            #[cfg(test)]
-            CompletionProvider::Fake(provider) => provider.complete(),
+        // Previously configured provider was changed to another one
+        if updated.is_none() {
+            if let Some(client) = self.client.clone() {
+                self.provider = create_provider_from_settings(client, version, cx);
+            } else {
+                log::warn!("completion provider cannot be created because client is not set");
+            }
         }
     }
 }
+
+fn create_provider_from_settings(
+    client: Arc<Client>,
+    settings_version: usize,
+    cx: &mut AppContext,
+) -> Arc<RwLock<dyn LanguageModelCompletionProvider>> {
+    match &AssistantSettings::get_global(cx).provider {
+        AssistantProvider::ZedDotDev { model } => Arc::new(RwLock::new(
+            CloudCompletionProvider::new(model.clone(), client.clone(), settings_version, cx),
+        )),
+        AssistantProvider::OpenAi {
+            model,
+            api_url,
+            low_speed_timeout_in_seconds,
+            available_models,
+        } => Arc::new(RwLock::new(OpenAiCompletionProvider::new(
+            choose_openai_model(&model, &available_models),
+            api_url.clone(),
+            client.http_client(),
+            low_speed_timeout_in_seconds.map(Duration::from_secs),
+            settings_version,
+        ))),
+        AssistantProvider::Anthropic {
+            model,
+            api_url,
+            low_speed_timeout_in_seconds,
+        } => Arc::new(RwLock::new(AnthropicCompletionProvider::new(
+            model.clone(),
+            api_url.clone(),
+            client.http_client(),
+            low_speed_timeout_in_seconds.map(Duration::from_secs),
+            settings_version,
+        ))),
+        AssistantProvider::Ollama {
+            model,
+            api_url,
+            low_speed_timeout_in_seconds,
+        } => Arc::new(RwLock::new(OllamaCompletionProvider::new(
+            model.clone(),
+            api_url.clone(),
+            client.http_client(),
+            low_speed_timeout_in_seconds.map(Duration::from_secs),
+            settings_version,
+            cx,
+        ))),
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use std::sync::Arc;
+
+    use gpui::AppContext;
+    use parking_lot::RwLock;
+    use settings::SettingsStore;
+    use smol::stream::StreamExt;
+
+    use crate::{
+        completion_provider::MAX_CONCURRENT_COMPLETION_REQUESTS, CompletionProvider,
+        FakeCompletionProvider, LanguageModelRequest,
+    };
+
+    #[gpui::test]
+    fn test_rate_limiting(cx: &mut AppContext) {
+        SettingsStore::test(cx);
+        let fake_provider = FakeCompletionProvider::setup_test(cx);
+
+        let provider = CompletionProvider::new(Arc::new(RwLock::new(fake_provider.clone())), None);
+
+        // Enqueue some requests
+        for i in 0..MAX_CONCURRENT_COMPLETION_REQUESTS * 2 {
+            let response = provider.complete(
+                LanguageModelRequest {
+                    temperature: i as f32 / 10.0,
+                    ..Default::default()
+                },
+                cx,
+            );
+            cx.background_executor()
+                .spawn(async move {
+                    let response = response.await;
+                    let mut stream = response.inner.await.unwrap();
+                    while let Some(message) = stream.next().await {
+                        message.unwrap();
+                    }
+                })
+                .detach();
+        }
+        cx.background_executor().run_until_parked();
+
+        assert_eq!(
+            fake_provider.completion_count(),
+            MAX_CONCURRENT_COMPLETION_REQUESTS
+        );
+
+        // Get the first completion request that is in flight and mark it as completed.
+        let completion = fake_provider
+            .running_completions()
+            .into_iter()
+            .next()
+            .unwrap();
+        fake_provider.finish_completion(&completion);
+
+        // Ensure that the number of in-flight completion requests is reduced.
+        assert_eq!(
+            fake_provider.completion_count(),
+            MAX_CONCURRENT_COMPLETION_REQUESTS - 1
+        );
+
+        cx.background_executor().run_until_parked();
+
+        // Ensure that another completion request was allowed to acquire the lock.
+        assert_eq!(
+            fake_provider.completion_count(),
+            MAX_CONCURRENT_COMPLETION_REQUESTS
+        );
+
+        // Mark all completion requests as finished that are in flight.
+        for request in fake_provider.running_completions() {
+            fake_provider.finish_completion(&request);
+        }
+
+        assert_eq!(fake_provider.completion_count(), 0);
+
+        // Wait until the background tasks acquire the lock again.
+        cx.background_executor().run_until_parked();
+
+        assert_eq!(
+            fake_provider.completion_count(),
+            MAX_CONCURRENT_COMPLETION_REQUESTS - 1
+        );
+
+        // Finish all remaining completion requests.
+        for request in fake_provider.running_completions() {
+            fake_provider.finish_completion(&request);
+        }
+
+        cx.background_executor().run_until_parked();
+
+        assert_eq!(fake_provider.completion_count(), 0);
+    }
+}
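The rate limiting above hinges on the `SemaphoreGuardArc` acquired in `complete` being stored inside the returned `CompletionResponse`: the permit is released when the response is dropped, not when `complete` returns, so a slow consumer keeps holding its slot. A standalone sketch of the same guard-in-struct pattern with `smol::lock::Semaphore` (names here are illustrative, not from the diff):

    use std::sync::Arc;
    use smol::lock::{Semaphore, SemaphoreGuardArc};

    struct GuardedWork {
        // Holding the guard ties the permit's lifetime to this value.
        _permit: SemaphoreGuardArc,
    }

    fn main() {
        smol::block_on(async {
            let limiter = Arc::new(Semaphore::new(2)); // at most 2 concurrent
            let permit = limiter.acquire_arc().await;  // waits while 2 permits are live
            let work = GuardedWork { _permit: permit };
            // ... perform the rate-limited work while `work` is alive ...
            drop(work); // the permit returns to the semaphore here
        });
    }
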
diff --git a/crates/assistant/src/completion_provider/anthropic.rs b/crates/assistant/src/completion_provider/anthropic.rs
index 87236501a9..b4c573588b 100644
--- a/crates/assistant/src/completion_provider/anthropic.rs
+++ b/crates/assistant/src/completion_provider/anthropic.rs
@@ -2,7 +2,7 @@ use crate::{
     assistant_settings::AnthropicModel, CompletionProvider, LanguageModel, LanguageModelRequest,
     Role,
 };
-use crate::{count_open_ai_tokens, LanguageModelRequestMessage};
+use crate::{count_open_ai_tokens, LanguageModelCompletionProvider, LanguageModelRequestMessage};
 use anthropic::{stream_completion, Request, RequestMessage};
 use anyhow::{anyhow, Result};
 use editor::{Editor, EditorElement, EditorStyle};
@@ -26,50 +26,22 @@ pub struct AnthropicCompletionProvider {
     settings_version: usize,
 }
 
-impl AnthropicCompletionProvider {
-    pub fn new(
-        model: AnthropicModel,
-        api_url: String,
-        http_client: Arc<HttpClientWithUrl>,
-        low_speed_timeout: Option<Duration>,
-        settings_version: usize,
-    ) -> Self {
-        Self {
-            api_key: None,
-            api_url,
-            model,
-            http_client,
-            low_speed_timeout,
-            settings_version,
-        }
-    }
-
-    pub fn update(
-        &mut self,
-        model: AnthropicModel,
-        api_url: String,
-        low_speed_timeout: Option<Duration>,
-        settings_version: usize,
-    ) {
-        self.model = model;
-        self.api_url = api_url;
-        self.low_speed_timeout = low_speed_timeout;
-        self.settings_version = settings_version;
-    }
-
-    pub fn available_models(&self) -> impl Iterator<Item = AnthropicModel> {
+impl LanguageModelCompletionProvider for AnthropicCompletionProvider {
+    fn available_models(&self, _cx: &AppContext) -> Vec<LanguageModel> {
         AnthropicModel::iter()
+            .map(LanguageModel::Anthropic)
+            .collect()
     }
 
-    pub fn settings_version(&self) -> usize {
+    fn settings_version(&self) -> usize {
         self.settings_version
     }
 
-    pub fn is_authenticated(&self) -> bool {
+    fn is_authenticated(&self) -> bool {
         self.api_key.is_some()
     }
 
-    pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
+    fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
         if self.is_authenticated() {
             Task::ready(Ok(()))
         } else {
@@ -85,36 +57,36 @@ impl AnthropicCompletionProvider {
                     String::from_utf8(api_key)?
                 };
                 cx.update_global::<CompletionProvider, _>(|provider, _cx| {
-                    if let CompletionProvider::Anthropic(provider) = provider {
+                    provider.update_current_as::<_, AnthropicCompletionProvider>(|provider| {
                         provider.api_key = Some(api_key);
-                    }
+                    });
                 })
             })
         }
     }
 
-    pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
+    fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
         let delete_credentials = cx.delete_credentials(&self.api_url);
         cx.spawn(|mut cx| async move {
             delete_credentials.await.log_err();
             cx.update_global::<CompletionProvider, _>(|provider, _cx| {
-                if let CompletionProvider::Anthropic(provider) = provider {
+                provider.update_current_as::<_, AnthropicCompletionProvider>(|provider| {
                     provider.api_key = None;
-                }
+                });
            })
         })
     }
 
-    pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
+    fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
         cx.new_view(|cx| AuthenticationPrompt::new(self.api_url.clone(), cx))
             .into()
     }
 
-    pub fn model(&self) -> AnthropicModel {
-        self.model.clone()
+    fn model(&self) -> LanguageModel {
+        LanguageModel::Anthropic(self.model.clone())
     }
 
-    pub fn count_tokens(
+    fn count_tokens(
         &self,
         request: LanguageModelRequest,
         cx: &AppContext,
@@ -122,7 +94,7 @@ impl AnthropicCompletionProvider {
         count_open_ai_tokens(request, cx.background_executor())
     }
 
-    pub fn complete(
+    fn complete(
         &self,
         request: LanguageModelRequest,
     ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
@@ -167,12 +139,48 @@ impl AnthropicCompletionProvider {
             .boxed()
     }
 
+    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
+        self
+    }
+}
+
+impl AnthropicCompletionProvider {
+    pub fn new(
+        model: AnthropicModel,
+        api_url: String,
+        http_client: Arc<HttpClientWithUrl>,
+        low_speed_timeout: Option<Duration>,
+        settings_version: usize,
+    ) -> Self {
+        Self {
+            api_key: None,
+            api_url,
+            model,
+            http_client,
+            low_speed_timeout,
+            settings_version,
+        }
+    }
+
+    pub fn update(
+        &mut self,
+        model: AnthropicModel,
+        api_url: String,
+        low_speed_timeout: Option<Duration>,
+        settings_version: usize,
+    ) {
+        self.model = model;
+        self.api_url = api_url;
+        self.low_speed_timeout = low_speed_timeout;
+        self.settings_version = settings_version;
+    }
+
     fn to_anthropic_request(&self, mut request: LanguageModelRequest) -> Request {
         preprocess_anthropic_request(&mut request);
         let model = match request.model {
             LanguageModel::Anthropic(model) => model,
-            _ => self.model(),
+            _ => self.model.clone(),
         };
 
         let mut system_message = String::new();
@@ -278,9 +286,9 @@ impl AuthenticationPrompt {
         cx.spawn(|_, mut cx| async move {
             write_credentials.await?;
             cx.update_global::<CompletionProvider, _>(|provider, _cx| {
-                if let CompletionProvider::Anthropic(provider) = provider {
+                provider.update_current_as::<_, AnthropicCompletionProvider>(|provider| {
                     provider.api_key = Some(api_key);
-                }
+                });
             })
         })
         .detach_and_log_err(cx);
diff --git a/crates/assistant/src/completion_provider/cloud.rs b/crates/assistant/src/completion_provider/cloud.rs
index 1112def519..c02e531ee9 100644
--- a/crates/assistant/src/completion_provider/cloud.rs
+++ b/crates/assistant/src/completion_provider/cloud.rs
@@ -1,6 +1,6 @@
 use crate::{
     assistant_settings::CloudModel, count_open_ai_tokens, CompletionProvider, LanguageModel,
-    LanguageModelRequest,
+    LanguageModelCompletionProvider, LanguageModelRequest,
 };
 use anyhow::{anyhow, Result};
 use client::{proto, Client};
@@ -30,11 +30,9 @@ impl CloudCompletionProvider {
         let maintain_client_status = cx.spawn(|mut cx| async move {
             while let Some(status) = status_rx.next().await {
                 let _ = cx.update_global::<CompletionProvider, _>(|provider, _cx| {
-                    if let CompletionProvider::Cloud(provider) = provider {
+                    provider.update_current_as::<_, Self>(|provider| {
                         provider.status = status;
-                    } else {
-                        unreachable!()
-                    }
+                    });
                 });
             }
         });
@@ -51,44 +49,53 @@ impl CloudCompletionProvider {
         self.model = model;
         self.settings_version = settings_version;
     }
+}
 
-    pub fn available_models(&self) -> impl Iterator<Item = CloudModel> {
+impl LanguageModelCompletionProvider for CloudCompletionProvider {
+    fn available_models(&self, _cx: &AppContext) -> Vec<LanguageModel> {
         let mut custom_model = if let CloudModel::Custom(custom_model) = self.model.clone() {
             Some(custom_model)
         } else {
             None
         };
-        CloudModel::iter().filter_map(move |model| {
-            if let CloudModel::Custom(_) = model {
-                Some(CloudModel::Custom(custom_model.take()?))
-            } else {
-                Some(model)
-            }
-        })
+        CloudModel::iter()
+            .filter_map(move |model| {
+                if let CloudModel::Custom(_) = model {
+                    Some(CloudModel::Custom(custom_model.take()?))
+                } else {
+                    Some(model)
+                }
+            })
+            .map(LanguageModel::Cloud)
+            .collect()
     }
 
-    pub fn settings_version(&self) -> usize {
+    fn settings_version(&self) -> usize {
         self.settings_version
     }
 
-    pub fn model(&self) -> CloudModel {
-        self.model.clone()
-    }
-
-    pub fn is_authenticated(&self) -> bool {
+    fn is_authenticated(&self) -> bool {
         self.status.is_connected()
     }
 
-    pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
+    fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
         let client = self.client.clone();
         cx.spawn(move |cx| async move { client.authenticate_and_connect(true, &cx).await })
     }
 
-    pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
+    fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
         cx.new_view(|_cx| AuthenticationPrompt).into()
     }
 
-    pub fn count_tokens(
+    fn reset_credentials(&self, _cx: &AppContext) -> Task<Result<()>> {
+        Task::ready(Ok(()))
+    }
+
+    fn model(&self) -> LanguageModel {
+        LanguageModel::Cloud(self.model.clone())
+    }
+
+    fn count_tokens(
         &self,
         request: LanguageModelRequest,
         cx: &AppContext,
@@ -128,7 +135,7 @@ impl CloudCompletionProvider {
         }
     }
 
-    pub fn complete(
+    fn complete(
         &self,
         mut request: LanguageModelRequest,
     ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
@@ -161,6 +168,10 @@ impl CloudCompletionProvider {
             })
             .boxed()
     }
+
+    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
+        self
+    }
 }
 
 struct AuthenticationPrompt;
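All three providers above now mutate the global state through `update_current_as`, which replaces the old enum matching with a downcast via `as_any_mut`. A minimal reduction of that pattern (the trait and types here are simplified stand-ins, not the diff's actual definitions):

    use std::any::Any;

    trait Provider: Any {
        fn as_any_mut(&mut self) -> &mut dyn Any;
    }

    struct OpenAi { api_key: Option<String> }

    impl Provider for OpenAi {
        fn as_any_mut(&mut self) -> &mut dyn Any { self }
    }

    // Run `update` only if the boxed provider is currently a `T`.
    fn update_current_as<R, T: Provider + 'static>(
        provider: &mut Box<dyn Provider>,
        update: impl FnOnce(&mut T) -> R,
    ) -> Option<R> {
        provider.as_any_mut().downcast_mut::<T>().map(update)
    }

    fn main() {
        let mut provider: Box<dyn Provider> = Box::new(OpenAi { api_key: None });
        // Some(()) because the active provider is an OpenAi; None otherwise.
        let updated = update_current_as::<_, OpenAi>(&mut provider, |p| {
            p.api_key = Some("key".into());
        });
        assert!(updated.is_some());
    }

The `Option` return is what lets `update_settings` detect a provider-kind switch: `None` means the configured provider no longer matches the active one, so a fresh provider is created from settings instead.
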
diff --git a/crates/assistant/src/completion_provider/fake.rs b/crates/assistant/src/completion_provider/fake.rs
index 9c06796a37..f07a3befd2 100644
--- a/crates/assistant/src/completion_provider/fake.rs
+++ b/crates/assistant/src/completion_provider/fake.rs
@@ -1,29 +1,107 @@
 use anyhow::Result;
+use collections::HashMap;
 use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
+use gpui::{AnyView, AppContext, Task};
 use parking_lot::Mutex;
 use std::sync::Arc;
+use ui::WindowContext;
+
+use crate::{LanguageModel, LanguageModelCompletionProvider, LanguageModelRequest};
 
 #[derive(Clone, Default)]
 pub struct FakeCompletionProvider {
-    current_completion_tx: Arc<Mutex<Option<mpsc::UnboundedSender<String>>>>,
+    current_completion_txs: Arc<Mutex<HashMap<String, mpsc::UnboundedSender<String>>>>,
 }
 
 impl FakeCompletionProvider {
-    pub fn complete(&self) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
-        let (tx, rx) = mpsc::unbounded();
-        *self.current_completion_tx.lock() = Some(tx);
-        async move { Ok(rx.map(Ok).boxed()) }.boxed()
+    #[cfg(test)]
+    pub fn setup_test(cx: &mut AppContext) -> Self {
+        use crate::CompletionProvider;
+        use parking_lot::RwLock;
+
+        let this = Self::default();
+        let provider = CompletionProvider::new(Arc::new(RwLock::new(this.clone())), None);
+        cx.set_global(provider);
+        this
     }
 
-    pub fn send_completion(&self, chunk: String) {
-        self.current_completion_tx
+    pub fn running_completions(&self) -> Vec<LanguageModelRequest> {
+        self.current_completion_txs
             .lock()
-            .as_ref()
+            .keys()
+            .map(|k| serde_json::from_str(k).unwrap())
+            .collect()
+    }
+
+    pub fn completion_count(&self) -> usize {
+        self.current_completion_txs.lock().len()
+    }
+
+    pub fn send_completion(&self, request: &LanguageModelRequest, chunk: String) {
+        let json = serde_json::to_string(request).unwrap();
+        self.current_completion_txs
+            .lock()
+            .get(&json)
             .unwrap()
             .unbounded_send(chunk)
             .unwrap();
     }
 
-    pub fn finish_completion(&self) {
-        self.current_completion_tx.lock().take();
+    pub fn finish_completion(&self, request: &LanguageModelRequest) {
+        self.current_completion_txs
+            .lock()
+            .remove(&serde_json::to_string(request).unwrap());
+    }
+}
+
+impl LanguageModelCompletionProvider for FakeCompletionProvider {
+    fn available_models(&self, _cx: &AppContext) -> Vec<LanguageModel> {
+        vec![LanguageModel::default()]
+    }
+
+    fn settings_version(&self) -> usize {
+        0
+    }
+
+    fn is_authenticated(&self) -> bool {
+        true
+    }
+
+    fn authenticate(&self, _cx: &AppContext) -> Task<Result<()>> {
+        Task::ready(Ok(()))
+    }
+
+    fn authentication_prompt(&self, _cx: &mut WindowContext) -> AnyView {
+        unimplemented!()
+    }
+
+    fn reset_credentials(&self, _cx: &AppContext) -> Task<Result<()>> {
+        Task::ready(Ok(()))
+    }
+
+    fn model(&self) -> LanguageModel {
+        LanguageModel::default()
+    }
+
+    fn count_tokens(
+        &self,
+        _request: LanguageModelRequest,
+        _cx: &AppContext,
+    ) -> BoxFuture<'static, Result<usize>> {
+        futures::future::ready(Ok(0)).boxed()
+    }
+
+    fn complete(
+        &self,
+        _request: LanguageModelRequest,
+    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
+        let (tx, rx) = mpsc::unbounded();
+        self.current_completion_txs
+            .lock()
+            .insert(serde_json::to_string(&_request).unwrap(), tx);
+        async move { Ok(rx.map(Ok).boxed()) }.boxed()
+    }
+
+    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
+        self
+    }
 }
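The fake provider above keys in-flight completions by the `serde_json` serialization of the request, which is why `LanguageModelRequest` gains `Deserialize` in assistant.rs, why the rate-limiting test varies `temperature` to make each request's key unique, and why tests must pass the same request value to `send_completion`/`finish_completion` that started the completion. A sketch of the test-side handshake, assuming a test context with the fake provider installed via `setup_test`:

    // let provider = FakeCompletionProvider::setup_test(cx);
    let request = LanguageModelRequest::default();
    // ... start a completion with `request` through CompletionProvider::global ...
    // An equal request value round-trips through JSON to the same map key:
    provider.send_completion(&request, "chunk".into()); // stream one chunk to the consumer
    provider.finish_completion(&request);               // drop the sender, ending the stream
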
diff --git a/crates/assistant/src/completion_provider/ollama.rs b/crates/assistant/src/completion_provider/ollama.rs
index e3a80de532..f782a20355 100644
--- a/crates/assistant/src/completion_provider/ollama.rs
+++ b/crates/assistant/src/completion_provider/ollama.rs
@@ -1,3 +1,4 @@
+use crate::LanguageModelCompletionProvider;
 use crate::{
     assistant_settings::OllamaModel, CompletionProvider, LanguageModel, LanguageModelRequest, Role,
 };
@@ -26,6 +27,108 @@ pub struct OllamaCompletionProvider {
     available_models: Vec<OllamaModel>,
 }
 
+impl LanguageModelCompletionProvider for OllamaCompletionProvider {
+    fn available_models(&self, _cx: &AppContext) -> Vec<LanguageModel> {
+        self.available_models
+            .iter()
+            .map(|m| LanguageModel::Ollama(m.clone()))
+            .collect()
+    }
+
+    fn settings_version(&self) -> usize {
+        self.settings_version
+    }
+
+    fn is_authenticated(&self) -> bool {
+        !self.available_models.is_empty()
+    }
+
+    fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
+        if self.is_authenticated() {
+            Task::ready(Ok(()))
+        } else {
+            self.fetch_models(cx)
+        }
+    }
+
+    fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
+        let fetch_models = Box::new(move |cx: &mut WindowContext| {
+            cx.update_global::<CompletionProvider, _>(|provider, cx| {
+                provider
+                    .update_current_as::<_, OllamaCompletionProvider>(|provider| {
+                        provider.fetch_models(cx)
+                    })
+                    .unwrap_or_else(|| Task::ready(Ok(())))
+            })
+        });
+
+        cx.new_view(|cx| DownloadOllamaMessage::new(fetch_models, cx))
+            .into()
+    }
+
+    fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
+        self.fetch_models(cx)
+    }
+
+    fn model(&self) -> LanguageModel {
+        LanguageModel::Ollama(self.model.clone())
+    }
+
+    fn count_tokens(
+        &self,
+        request: LanguageModelRequest,
+        _cx: &AppContext,
+    ) -> BoxFuture<'static, Result<usize>> {
+        // There is no endpoint for this _yet_ in Ollama
+        // see: https://github.com/ollama/ollama/issues/1716 and https://github.com/ollama/ollama/issues/3582
+        let token_count = request
+            .messages
+            .iter()
+            .map(|msg| msg.content.chars().count())
+            .sum::<usize>()
+            / 4;
+
+        async move { Ok(token_count) }.boxed()
+    }
+
+    fn complete(
+        &self,
+        request: LanguageModelRequest,
+    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
+        let request = self.to_ollama_request(request);
+
+        let http_client = self.http_client.clone();
+        let api_url = self.api_url.clone();
+        let low_speed_timeout = self.low_speed_timeout;
+        async move {
+            let request =
+                stream_chat_completion(http_client.as_ref(), &api_url, request, low_speed_timeout);
+            let response = request.await?;
+            let stream = response
+                .filter_map(|response| async move {
+                    match response {
+                        Ok(delta) => {
+                            let content = match delta.message {
+                                ChatMessage::User { content } => content,
+                                ChatMessage::Assistant { content } => content,
+                                ChatMessage::System { content } => content,
+                            };
+                            Some(Ok(content))
+                        }
+                        Err(error) => Some(Err(error)),
+                    }
+                })
+                .boxed();
+            Ok(stream)
+        }
+        .boxed()
+    }
+
+    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
+        self
+    }
+}
+
 impl OllamaCompletionProvider {
     pub fn new(
         model: OllamaModel,
@@ -87,36 +190,12 @@ impl OllamaCompletionProvider {
         self.settings_version = settings_version;
     }
 
-    pub fn available_models(&self) -> impl Iterator<Item = &OllamaModel> {
-        self.available_models.iter()
-    }
-
     pub fn select_first_available_model(&mut self) {
         if let Some(model) = self.available_models.first() {
             self.model = model.clone();
         }
     }
 
-    pub fn settings_version(&self) -> usize {
-        self.settings_version
-    }
-
-    pub fn is_authenticated(&self) -> bool {
-        !self.available_models.is_empty()
-    }
-
-    pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
-        if self.is_authenticated() {
-            Task::ready(Ok(()))
-        } else {
-            self.fetch_models(cx)
-        }
-    }
-
-    pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
-        self.fetch_models(cx)
-    }
-
     pub fn fetch_models(&self, cx: &AppContext) -> Task<Result<()>> {
         let http_client = self.http_client.clone();
         let api_url = self.api_url.clone();
@@ -137,90 +216,21 @@ impl OllamaCompletionProvider {
             models.sort_by(|a, b| a.name.cmp(&b.name));
 
             cx.update_global::<CompletionProvider, _>(|provider, _cx| {
-                if let CompletionProvider::Ollama(provider) = provider {
+                provider.update_current_as::<_, OllamaCompletionProvider>(|provider| {
                     provider.available_models = models;
 
                     if !provider.available_models.is_empty() && provider.model.name.is_empty() {
                         provider.select_first_available_model()
                     }
-                }
+                });
             })
         })
     }
 
-    pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
-        let fetch_models = Box::new(move |cx: &mut WindowContext| {
-            cx.update_global::<CompletionProvider, _>(|provider, cx| {
-                if let CompletionProvider::Ollama(provider) = provider {
-                    provider.fetch_models(cx)
-                } else {
-                    Task::ready(Ok(()))
-                }
-            })
-        });
-
-        cx.new_view(|cx| DownloadOllamaMessage::new(fetch_models, cx))
-            .into()
-    }
-
-    pub fn model(&self) -> OllamaModel {
-        self.model.clone()
-    }
-
-    pub fn count_tokens(
-        &self,
-        request: LanguageModelRequest,
-        _cx: &AppContext,
-    ) -> BoxFuture<'static, Result<usize>> {
-        // There is no endpoint for this _yet_ in Ollama
-        // see: https://github.com/ollama/ollama/issues/1716 and https://github.com/ollama/ollama/issues/3582
-        let token_count = request
-            .messages
-            .iter()
-            .map(|msg| msg.content.chars().count())
-            .sum::<usize>()
-            / 4;
-
-        async move { Ok(token_count) }.boxed()
-    }
-
-    pub fn complete(
-        &self,
-        request: LanguageModelRequest,
-    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
-        let request = self.to_ollama_request(request);
-
-        let http_client = self.http_client.clone();
-        let api_url = self.api_url.clone();
-        let low_speed_timeout = self.low_speed_timeout;
-        async move {
-            let request =
-                stream_chat_completion(http_client.as_ref(), &api_url, request, low_speed_timeout);
-            let response = request.await?;
-            let stream = response
-                .filter_map(|response| async move {
-                    match response {
-                        Ok(delta) => {
-                            let content = match delta.message {
-                                ChatMessage::User { content } => content,
-                                ChatMessage::Assistant { content } => content,
-                                ChatMessage::System { content } => content,
-                            };
-                            Some(Ok(content))
-                        }
-                        Err(error) => Some(Err(error)),
-                    }
-                })
-                .boxed();
-            Ok(stream)
-        }
-        .boxed()
-    }
-
     fn to_ollama_request(&self, request: LanguageModelRequest) -> ChatRequest {
         let model = match request.model {
             LanguageModel::Ollama(model) => model,
-            _ => self.model(),
+            _ => self.model.clone(),
         };
 
         ChatRequest {
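Since Ollama exposes no tokenizer endpoint (per the issues linked in the code), `count_tokens` above estimates roughly one token per four characters of message content. The heuristic in isolation, with a worked example:

    // Rough estimate: ~4 characters per token, summed over all messages.
    fn estimate_tokens(messages: &[String]) -> usize {
        messages.iter().map(|m| m.chars().count()).sum::<usize>() / 4
    }

    fn main() {
        let messages = vec!["Hello, world!".to_string()]; // 13 chars
        assert_eq!(estimate_tokens(&messages), 3);        // 13 / 4 = 3 (integer division)
    }
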
diff --git a/crates/assistant/src/completion_provider/open_ai.rs b/crates/assistant/src/completion_provider/open_ai.rs
index f4459faf14..6c16e2c9a6 100644
--- a/crates/assistant/src/completion_provider/open_ai.rs
+++ b/crates/assistant/src/completion_provider/open_ai.rs
@@ -1,5 +1,6 @@
 use crate::assistant_settings::CloudModel;
 use crate::assistant_settings::{AssistantProvider, AssistantSettings};
+use crate::LanguageModelCompletionProvider;
 use crate::{
     assistant_settings::OpenAiModel, CompletionProvider, LanguageModel, LanguageModelRequest, Role,
 };
@@ -57,37 +58,75 @@ impl OpenAiCompletionProvider {
         self.settings_version = settings_version;
     }
 
-    pub fn available_models(&self, cx: &AppContext) -> impl Iterator<Item = OpenAiModel> {
+    fn to_open_ai_request(&self, request: LanguageModelRequest) -> Request {
+        let model = match request.model {
+            LanguageModel::OpenAi(model) => model,
+            _ => self.model.clone(),
+        };
+
+        Request {
+            model,
+            messages: request
+                .messages
+                .into_iter()
+                .map(|msg| match msg.role {
+                    Role::User => RequestMessage::User {
+                        content: msg.content,
+                    },
+                    Role::Assistant => RequestMessage::Assistant {
+                        content: Some(msg.content),
+                        tool_calls: Vec::new(),
+                    },
+                    Role::System => RequestMessage::System {
+                        content: msg.content,
+                    },
+                })
+                .collect(),
+            stream: true,
+            stop: request.stop,
+            temperature: request.temperature,
+            tools: Vec::new(),
+            tool_choice: None,
+        }
+    }
+}
+
+impl LanguageModelCompletionProvider for OpenAiCompletionProvider {
+    fn available_models(&self, cx: &AppContext) -> Vec<LanguageModel> {
         if let AssistantProvider::OpenAi {
             available_models, ..
         } = &AssistantSettings::get_global(cx).provider
         {
             if !available_models.is_empty() {
-                // available_models is set, just return it
-                return available_models.clone().into_iter();
+                return available_models
+                    .iter()
+                    .cloned()
+                    .map(LanguageModel::OpenAi)
+                    .collect();
             }
         }
         let available_models = if matches!(self.model, OpenAiModel::Custom { .. }) {
-            // available_models is not set but the default model is set to custom, only show custom
             vec![self.model.clone()]
         } else {
-            // default case, use all models except custom
             OpenAiModel::iter()
                 .filter(|model| !matches!(model, OpenAiModel::Custom { .. }))
                 .collect()
         };
-        available_models.into_iter()
+        available_models
+            .into_iter()
+            .map(LanguageModel::OpenAi)
+            .collect()
     }
 
-    pub fn settings_version(&self) -> usize {
+    fn settings_version(&self) -> usize {
         self.settings_version
     }
 
-    pub fn is_authenticated(&self) -> bool {
+    fn is_authenticated(&self) -> bool {
         self.api_key.is_some()
     }
 
-    pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
+    fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
         if self.is_authenticated() {
             Task::ready(Ok(()))
         } else {
@@ -103,36 +142,36 @@ impl OpenAiCompletionProvider {
                     String::from_utf8(api_key)?
                 };
                 cx.update_global::<CompletionProvider, _>(|provider, _cx| {
-                    if let CompletionProvider::OpenAi(provider) = provider {
+                    provider.update_current_as::<_, Self>(|provider| {
                         provider.api_key = Some(api_key);
-                    }
+                    });
                 })
             })
         }
     }
 
-    pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
+    fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
         let delete_credentials = cx.delete_credentials(&self.api_url);
         cx.spawn(|mut cx| async move {
             delete_credentials.await.log_err();
             cx.update_global::<CompletionProvider, _>(|provider, _cx| {
-                if let CompletionProvider::OpenAi(provider) = provider {
+                provider.update_current_as::<_, Self>(|provider| {
                     provider.api_key = None;
-                }
+                });
            })
         })
     }
 
-    pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
+    fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
         cx.new_view(|cx| AuthenticationPrompt::new(self.api_url.clone(), cx))
             .into()
     }
 
-    pub fn model(&self) -> OpenAiModel {
-        self.model.clone()
+    fn model(&self) -> LanguageModel {
+        LanguageModel::OpenAi(self.model.clone())
     }
 
-    pub fn count_tokens(
+    fn count_tokens(
        &self,
         request: LanguageModelRequest,
         cx: &AppContext,
@@ -140,7 +179,7 @@ impl OpenAiCompletionProvider {
         count_open_ai_tokens(request, cx.background_executor())
     }
 
-    pub fn complete(
+    fn complete(
         &self,
         request: LanguageModelRequest,
     ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
@@ -173,36 +212,8 @@ impl OpenAiCompletionProvider {
             .boxed()
     }
 
-    fn to_open_ai_request(&self, request: LanguageModelRequest) -> Request {
-        let model = match request.model {
-            LanguageModel::OpenAi(model) => model,
-            _ => self.model(),
-        };
-
-        Request {
-            model,
-            messages: request
-                .messages
-                .into_iter()
-                .map(|msg| match msg.role {
-                    Role::User => RequestMessage::User {
-                        content: msg.content,
-                    },
-                    Role::Assistant => RequestMessage::Assistant {
-                        content: Some(msg.content),
-                        tool_calls: Vec::new(),
-                    },
-                    Role::System => RequestMessage::System {
-                        content: msg.content,
-                    },
-                })
-                .collect(),
-            stream: true,
-            stop: request.stop,
-            temperature: request.temperature,
-            tools: Vec::new(),
-            tool_choice: None,
-        }
+    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
+        self
     }
 }
 
@@ -284,9 +295,9 @@ impl AuthenticationPrompt {
         cx.spawn(|_, mut cx| async move {
             write_credentials.await?;
             cx.update_global::<CompletionProvider, _>(|provider, _cx| {
-                if let CompletionProvider::OpenAi(provider) = provider {
+                provider.update_current_as::<_, OpenAiCompletionProvider>(|provider| {
                     provider.api_key = Some(api_key);
-                }
+                });
             })
         })
         .detach_and_log_err(cx);
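A detail shared by the OpenAI, Anthropic, and Ollama request builders in this diff: if the request's `model` variant doesn't match the provider, they silently fall back to the provider's configured model rather than erroring. The shape of that fallback, reduced to its essentials (the enum and struct here are stand-ins for illustration):

    enum LanguageModel { OpenAi(String), Anthropic(String) }

    struct OpenAiProvider { model: String }

    impl OpenAiProvider {
        // Use the request's model when it is an OpenAI model; otherwise
        // keep whatever this provider is configured with.
        fn resolve_model(&self, requested: LanguageModel) -> String {
            match requested {
                LanguageModel::OpenAi(model) => model,
                _ => self.model.clone(),
            }
        }
    }
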
diff --git a/crates/assistant/src/inline_assistant.rs b/crates/assistant/src/inline_assistant.rs
index 3c3a799126..4ea5696ca7 100644
--- a/crates/assistant/src/inline_assistant.rs
+++ b/crates/assistant/src/inline_assistant.rs
@@ -1986,13 +1986,14 @@ impl Codegen {
             .unwrap_or_else(|| snapshot.indent_size_for_line(MultiBufferRow(selection_start.row)));
 
         let model_telemetry_id = prompt.model.telemetry_id();
-        let response = CompletionProvider::global(cx).complete(prompt);
+        let response = CompletionProvider::global(cx).complete(prompt, cx);
         let telemetry = self.telemetry.clone();
         self.edit_position = range.start;
         self.diff = Diff::default();
         self.status = CodegenStatus::Pending;
         self.generation = cx.spawn(|this, mut cx| {
             async move {
+                let response = response.await;
                 let generate = async {
                     let mut edit_start = range.start.to_offset(&snapshot);
 
@@ -2002,7 +2003,7 @@ impl Codegen {
                     let mut response_latency = None;
                     let request_start = Instant::now();
                     let diff = async {
-                        let chunks = StripInvalidSpans::new(response.await?);
+                        let chunks = StripInvalidSpans::new(response.inner.await?);
                         futures::pin_mut!(chunks);
                         let mut diff = StreamingDiff::new(selected_text.to_string());
 
@@ -2473,9 +2474,8 @@ mod tests {
 
     #[gpui::test(iterations = 10)]
     async fn test_transform_autoindent(cx: &mut TestAppContext, mut rng: StdRng) {
-        let provider = FakeCompletionProvider::default();
         cx.set_global(cx.update(SettingsStore::test));
-        cx.set_global(CompletionProvider::Fake(provider.clone()));
+        let provider = cx.update(|cx| FakeCompletionProvider::setup_test(cx));
         cx.update(language_settings::init);
 
         let text = indoc! {"
@@ -2495,8 +2495,11 @@ mod tests {
         });
         let codegen = cx.new_model(|cx| Codegen::new(buffer.clone(), range, None, cx));
 
-        let request = LanguageModelRequest::default();
-        codegen.update(cx, |codegen, cx| codegen.start(request, cx));
+        codegen.update(cx, |codegen, cx| {
+            codegen.start(LanguageModelRequest::default(), cx)
+        });
+
+        cx.background_executor.run_until_parked();
 
         let mut new_text = concat!(
             " let mut x = 0;\n",
@@ -2508,11 +2511,11 @@ mod tests {
             let max_len = cmp::min(new_text.len(), 10);
             let len = rng.gen_range(1..=max_len);
             let (chunk, suffix) = new_text.split_at(len);
-            provider.send_completion(chunk.into());
+            provider.send_completion(&LanguageModelRequest::default(), chunk.into());
             new_text = suffix;
             cx.background_executor.run_until_parked();
         }
-        provider.finish_completion();
+        provider.finish_completion(&LanguageModelRequest::default());
        cx.background_executor.run_until_parked();
 
         assert_eq!(
@@ -2533,8 +2536,7 @@ mod tests {
         cx: &mut TestAppContext,
         mut rng: StdRng,
     ) {
-        let provider = FakeCompletionProvider::default();
-        cx.set_global(CompletionProvider::Fake(provider.clone()));
+        let provider = cx.update(|cx| FakeCompletionProvider::setup_test(cx));
         cx.set_global(cx.update(SettingsStore::test));
         cx.update(language_settings::init);
 
@@ -2555,6 +2557,8 @@ mod tests {
         let request = LanguageModelRequest::default();
         codegen.update(cx, |codegen, cx| codegen.start(request, cx));
 
+        cx.background_executor.run_until_parked();
+
         let mut new_text = concat!(
             "t mut x = 0;\n",
             "while x < 10 {\n",
@@ -2565,11 +2569,11 @@ mod tests {
             let max_len = cmp::min(new_text.len(), 10);
             let len = rng.gen_range(1..=max_len);
             let (chunk, suffix) = new_text.split_at(len);
-            provider.send_completion(chunk.into());
+            provider.send_completion(&LanguageModelRequest::default(), chunk.into());
             new_text = suffix;
             cx.background_executor.run_until_parked();
         }
-        provider.finish_completion();
+        provider.finish_completion(&LanguageModelRequest::default());
         cx.background_executor.run_until_parked();
 
         assert_eq!(
@@ -2590,8 +2594,7 @@ mod tests {
         cx: &mut TestAppContext,
         mut rng: StdRng,
     ) {
-        let provider = FakeCompletionProvider::default();
-        cx.set_global(CompletionProvider::Fake(provider.clone()));
+        let provider = cx.update(|cx| FakeCompletionProvider::setup_test(cx));
         cx.set_global(cx.update(SettingsStore::test));
         cx.update(language_settings::init);
 
@@ -2612,6 +2615,8 @@ mod tests {
         let request = LanguageModelRequest::default();
         codegen.update(cx, |codegen, cx| codegen.start(request, cx));
 
+        cx.background_executor.run_until_parked();
+
         let mut new_text = concat!(
             "let mut x = 0;\n",
             "while x < 10 {\n",
@@ -2622,11 +2627,11 @@ mod tests {
             let max_len = cmp::min(new_text.len(), 10);
             let len = rng.gen_range(1..=max_len);
             let (chunk, suffix) = new_text.split_at(len);
-            provider.send_completion(chunk.into());
+            provider.send_completion(&LanguageModelRequest::default(), chunk.into());
             new_text = suffix;
             cx.background_executor.run_until_parked();
         }
-        provider.finish_completion();
+        provider.finish_completion(&LanguageModelRequest::default());
         cx.background_executor.run_until_parked();
 
         assert_eq!(
diff --git a/crates/assistant/src/terminal_inline_assistant.rs b/crates/assistant/src/terminal_inline_assistant.rs
index 13c52a29bb..ac0bb9af91 100644
--- a/crates/assistant/src/terminal_inline_assistant.rs
+++ b/crates/assistant/src/terminal_inline_assistant.rs
@@ -1026,9 +1026,10 @@ impl Codegen {
         let telemetry = self.telemetry.clone();
         let model_telemetry_id = prompt.model.telemetry_id();
-        let response = CompletionProvider::global(cx).complete(prompt);
+        let response = CompletionProvider::global(cx).complete(prompt, cx);
 
         self.generation = cx.spawn(|this, mut cx| async move {
+            let response = response.await;
             let generate = async {
                 let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1);
 
@@ -1036,7 +1037,7 @@ impl Codegen {
                 let mut response_latency = None;
                 let request_start = Instant::now();
                 let task = async {
-                    let mut response = response.await?;
+                    let mut response = response.inner.await?;
                     while let Some(chunk) = response.next().await {
                         if response_latency.is_none() {
                             response_latency = Some(request_start.elapsed());
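One behavioral consequence worth noting about the test changes in inline_assistant.rs above: because `complete` now resolves on the background executor (a semaphore permit must be acquired before the fake provider registers the completion), the tests insert a `run_until_parked` between `codegen.start(...)` and the first `send_completion`. A sketch of the required ordering, assuming the fake-provider test setup from this diff:

    // codegen.update(cx, |codegen, cx| codegen.start(LanguageModelRequest::default(), cx));
    // Without this, the fake provider has not registered the completion yet:
    cx.background_executor.run_until_parked();
    provider.send_completion(&LanguageModelRequest::default(), "chunk".into());
    provider.finish_completion(&LanguageModelRequest::default());
    cx.background_executor.run_until_parked();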