assistant: Limit amount of concurrent completion requests (#13856)

This PR refactors the completion providers to process only a bounded number
of completion requests at a time.
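
The gist of the change: a `complete` call now acquires a permit from a semaphore before hitting the provider, and the returned response holds that permit until it is dropped, at which point the next queued request may proceed. Below is a minimal, self-contained sketch of that pattern, assuming the `smol` executor and `smol::lock::Semaphore` that the diff uses; `fake_complete` and the other names are illustrative, not the PR's actual API.

```rust
// Sketch of the gating pattern this PR introduces: a semaphore permit is
// acquired before a request starts and stored alongside the response, so
// the permit is released only when the response is dropped.
use std::sync::Arc;

use smol::lock::{Semaphore, SemaphoreGuardArc};

const MAX_CONCURRENT_COMPLETION_REQUESTS: usize = 4;

struct CompletionResponse {
    text: String,
    // Held for the lifetime of the response; dropping it releases the permit.
    _lock: SemaphoreGuardArc,
}

// Stand-in for a real provider call; not part of the PR.
async fn fake_complete(prompt: &str) -> String {
    format!("echo: {prompt}")
}

fn main() {
    let limiter = Arc::new(Semaphore::new(MAX_CONCURRENT_COMPLETION_REQUESTS));
    smol::block_on(async {
        // Only four of these can be inside `fake_complete` at once; the rest
        // wait on `acquire_arc` until an earlier response is dropped.
        let tasks: Vec<_> = (0..8)
            .map(|i| {
                let limiter = limiter.clone();
                smol::spawn(async move {
                    let lock = limiter.acquire_arc().await;
                    let text = fake_complete(&format!("request {i}")).await;
                    CompletionResponse { text, _lock: lock }
                })
            })
            .collect();
        for task in tasks {
            let response = task.await;
            println!("{}", response.text);
        }
    });
}
```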

Also started refactoring the language model providers to use traits, making
it easier to support multiple providers in the future.
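
The trait refactor replaces the old `CompletionProvider` enum with a single `LanguageModelCompletionProvider` trait object, plus an `as_any_mut` hook so settings updates can downcast back to the concrete provider. A rough sketch of that shape, with simplified hypothetical names that are not the PR's actual API:

```rust
// Sketch of the trait-object + downcast pattern the refactor moves toward.
use std::any::Any;

trait CompletionProviderBackend: Send + Sync {
    fn name(&self) -> &'static str;
    // Escape hatch that lets callers recover the concrete type.
    fn as_any_mut(&mut self) -> &mut dyn Any;
}

#[derive(Default)]
struct FakeBackend {
    api_key: Option<String>,
}

impl CompletionProviderBackend for FakeBackend {
    fn name(&self) -> &'static str {
        "fake"
    }

    fn as_any_mut(&mut self) -> &mut dyn Any {
        self
    }
}

/// Runs `update` only if the active backend is a `T`, mirroring the PR's
/// `update_current_as`. Returns `None` when another provider is configured.
fn update_current_as<R, T: CompletionProviderBackend + 'static>(
    backend: &mut dyn CompletionProviderBackend,
    update: impl FnOnce(&mut T) -> R,
) -> Option<R> {
    backend.as_any_mut().downcast_mut::<T>().map(update)
}

fn main() {
    let mut backend: Box<dyn CompletionProviderBackend> = Box::new(FakeBackend::default());
    let updated = update_current_as::<_, FakeBackend>(backend.as_mut(), |fake| {
        fake.api_key = Some("key".into());
    });
    assert!(updated.is_some());
    println!("updated backend: {}", backend.name());
}
```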

Release Notes:

- N/A
Bennet Bo Fenner 2024-07-05 14:52:45 +02:00 committed by GitHub
parent f2711b2fca
commit c4dbe32f20
11 changed files with 693 additions and 532 deletions

View file

@@ -163,7 +163,7 @@ impl LanguageModelRequestMessage {
     }
 }
 
-#[derive(Debug, Default, Serialize)]
+#[derive(Debug, Default, Serialize, Deserialize)]
 pub struct LanguageModelRequest {
     pub model: LanguageModel,
     pub messages: Vec<LanguageModelRequestMessage>,

View file

@@ -1409,7 +1409,7 @@ impl Context {
         }
 
         let request = self.to_completion_request(cx);
-        let stream = CompletionProvider::global(cx).complete(request);
+        let response = CompletionProvider::global(cx).complete(request, cx);
         let assistant_message = self
             .insert_message_after(last_message_id, Role::Assistant, MessageStatus::Pending, cx)
             .unwrap();
@@ -1422,11 +1422,12 @@ impl Context {
         let task = cx.spawn({
             |this, mut cx| async move {
+                let response = response.await;
                 let assistant_message_id = assistant_message.id;
                 let mut response_latency = None;
                 let stream_completion = async {
                     let request_start = Instant::now();
-                    let mut messages = stream.await?;
+                    let mut messages = response.inner.await?;
 
                     while let Some(message) = messages.next().await {
                         if response_latency.is_none() {
@@ -1718,10 +1719,11 @@ impl Context {
             temperature: 1.0,
         };
 
-        let stream = CompletionProvider::global(cx).complete(request);
+        let response = CompletionProvider::global(cx).complete(request, cx);
         self.pending_summary = cx.spawn(|this, mut cx| {
             async move {
-                let mut messages = stream.await?;
+                let response = response.await;
+                let mut messages = response.inner.await?;
 
                 while let Some(message) = messages.next().await {
                     let text = message?;
@@ -3642,7 +3644,7 @@ mod tests {
     #[gpui::test]
     fn test_inserting_and_removing_messages(cx: &mut AppContext) {
         let settings_store = SettingsStore::test(cx);
-        cx.set_global(CompletionProvider::Fake(FakeCompletionProvider::default()));
+        FakeCompletionProvider::setup_test(cx);
         cx.set_global(settings_store);
         init(cx);
         let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
@@ -3774,7 +3776,7 @@ mod tests {
     fn test_message_splitting(cx: &mut AppContext) {
         let settings_store = SettingsStore::test(cx);
         cx.set_global(settings_store);
-        cx.set_global(CompletionProvider::Fake(FakeCompletionProvider::default()));
+        FakeCompletionProvider::setup_test(cx);
         init(cx);
         let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
@@ -3867,7 +3869,7 @@ mod tests {
     #[gpui::test]
     fn test_messages_for_offsets(cx: &mut AppContext) {
         let settings_store = SettingsStore::test(cx);
-        cx.set_global(CompletionProvider::Fake(FakeCompletionProvider::default()));
+        FakeCompletionProvider::setup_test(cx);
         cx.set_global(settings_store);
         init(cx);
         let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
@@ -3952,7 +3954,8 @@ mod tests {
     async fn test_slash_commands(cx: &mut TestAppContext) {
         let settings_store = cx.update(SettingsStore::test);
         cx.set_global(settings_store);
-        cx.set_global(CompletionProvider::Fake(FakeCompletionProvider::default()));
+        cx.update(|cx| FakeCompletionProvider::setup_test(cx));
         cx.update(Project::init_settings);
         cx.update(init);
         let fs = FakeFs::new(cx.background_executor.clone());
@@ -4147,7 +4150,7 @@ mod tests {
     async fn test_serialization(cx: &mut TestAppContext) {
         let settings_store = cx.update(SettingsStore::test);
         cx.set_global(settings_store);
-        cx.set_global(CompletionProvider::Fake(FakeCompletionProvider::default()));
+        cx.update(FakeCompletionProvider::setup_test);
         cx.update(init);
         let registry = Arc::new(LanguageRegistry::test(cx.executor()));
         let context =

View file

@@ -1,5 +1,6 @@
 use std::fmt;
 
+use crate::{preprocess_anthropic_request, LanguageModel, LanguageModelRequest};
 pub use anthropic::Model as AnthropicModel;
 use gpui::Pixels;
 pub use ollama::Model as OllamaModel;
@@ -15,8 +16,6 @@ use serde::{
 use settings::{Settings, SettingsSources};
 use strum::{EnumIter, IntoEnumIterator};
 
-use crate::{preprocess_anthropic_request, LanguageModel, LanguageModelRequest};
-
 #[derive(Clone, Debug, Default, PartialEq, EnumIter)]
 pub enum CloudModel {
     Gpt3Point5Turbo,

View file

@@ -11,6 +11,8 @@ pub use cloud::*;
 pub use fake::*;
 pub use ollama::*;
 pub use open_ai::*;
+use parking_lot::RwLock;
+use smol::lock::{Semaphore, SemaphoreGuardArc};
 
 use crate::{
     assistant_settings::{AssistantProvider, AssistantSettings},
@@ -21,8 +23,8 @@ use client::Client;
 use futures::{future::BoxFuture, stream::BoxStream};
 use gpui::{AnyView, AppContext, BorrowAppContext, Task, WindowContext};
 use settings::{Settings, SettingsStore};
-use std::sync::Arc;
 use std::time::Duration;
+use std::{any::Any, sync::Arc};
 
 /// Choose which model to use for openai provider.
 /// If the model is not available, try to use the first available model, or fallback to the original model.
@@ -39,176 +41,117 @@ fn choose_openai_model(
 }
 
 pub fn init(client: Arc<Client>, cx: &mut AppContext) {
-    let mut settings_version = 0;
-    let provider = match &AssistantSettings::get_global(cx).provider {
-        AssistantProvider::ZedDotDev { model } => CompletionProvider::Cloud(
-            CloudCompletionProvider::new(model.clone(), client.clone(), settings_version, cx),
-        ),
-        AssistantProvider::OpenAi {
-            model,
-            api_url,
-            low_speed_timeout_in_seconds,
-            available_models,
-        } => CompletionProvider::OpenAi(OpenAiCompletionProvider::new(
-            choose_openai_model(model, available_models),
-            api_url.clone(),
-            client.http_client(),
-            low_speed_timeout_in_seconds.map(Duration::from_secs),
-            settings_version,
-        )),
-        AssistantProvider::Anthropic {
-            model,
-            api_url,
-            low_speed_timeout_in_seconds,
-        } => CompletionProvider::Anthropic(AnthropicCompletionProvider::new(
-            model.clone(),
-            api_url.clone(),
-            client.http_client(),
-            low_speed_timeout_in_seconds.map(Duration::from_secs),
-            settings_version,
-        )),
-        AssistantProvider::Ollama {
-            model,
-            api_url,
-            low_speed_timeout_in_seconds,
-        } => CompletionProvider::Ollama(OllamaCompletionProvider::new(
-            model.clone(),
-            api_url.clone(),
-            client.http_client(),
-            low_speed_timeout_in_seconds.map(Duration::from_secs),
-            settings_version,
-            cx,
-        )),
-    };
-    cx.set_global(provider);
+    let provider = create_provider_from_settings(client.clone(), 0, cx);
+    cx.set_global(CompletionProvider::new(provider, Some(client)));
 
+    let mut settings_version = 0;
     cx.observe_global::<SettingsStore>(move |cx| {
         settings_version += 1;
         cx.update_global::<CompletionProvider, _>(|provider, cx| {
-            match (&mut *provider, &AssistantSettings::get_global(cx).provider) {
-                (
-                    CompletionProvider::OpenAi(provider),
-                    AssistantProvider::OpenAi {
-                        model,
-                        api_url,
-                        low_speed_timeout_in_seconds,
-                        available_models,
-                    },
-                ) => {
-                    provider.update(
-                        choose_openai_model(model, available_models),
-                        api_url.clone(),
-                        low_speed_timeout_in_seconds.map(Duration::from_secs),
-                        settings_version,
-                    );
-                }
-                (
-                    CompletionProvider::Anthropic(provider),
-                    AssistantProvider::Anthropic {
-                        model,
-                        api_url,
-                        low_speed_timeout_in_seconds,
-                    },
-                ) => {
-                    provider.update(
-                        model.clone(),
-                        api_url.clone(),
-                        low_speed_timeout_in_seconds.map(Duration::from_secs),
-                        settings_version,
-                    );
-                }
-                (
-                    CompletionProvider::Ollama(provider),
-                    AssistantProvider::Ollama {
-                        model,
-                        api_url,
-                        low_speed_timeout_in_seconds,
-                    },
-                ) => {
-                    provider.update(
-                        model.clone(),
-                        api_url.clone(),
-                        low_speed_timeout_in_seconds.map(Duration::from_secs),
-                        settings_version,
-                        cx,
-                    );
-                }
-                (CompletionProvider::Cloud(provider), AssistantProvider::ZedDotDev { model }) => {
-                    provider.update(model.clone(), settings_version);
-                }
-                (_, AssistantProvider::ZedDotDev { model }) => {
-                    *provider = CompletionProvider::Cloud(CloudCompletionProvider::new(
-                        model.clone(),
-                        client.clone(),
-                        settings_version,
-                        cx,
-                    ));
-                }
-                (
-                    _,
-                    AssistantProvider::OpenAi {
-                        model,
-                        api_url,
-                        low_speed_timeout_in_seconds,
-                        available_models,
-                    },
-                ) => {
-                    *provider = CompletionProvider::OpenAi(OpenAiCompletionProvider::new(
-                        choose_openai_model(model, available_models),
-                        api_url.clone(),
-                        client.http_client(),
-                        low_speed_timeout_in_seconds.map(Duration::from_secs),
-                        settings_version,
-                    ));
-                }
-                (
-                    _,
-                    AssistantProvider::Anthropic {
-                        model,
-                        api_url,
-                        low_speed_timeout_in_seconds,
-                    },
-                ) => {
-                    *provider = CompletionProvider::Anthropic(AnthropicCompletionProvider::new(
-                        model.clone(),
-                        api_url.clone(),
-                        client.http_client(),
-                        low_speed_timeout_in_seconds.map(Duration::from_secs),
-                        settings_version,
-                    ));
-                }
-                (
-                    _,
-                    AssistantProvider::Ollama {
-                        model,
-                        api_url,
-                        low_speed_timeout_in_seconds,
-                    },
-                ) => {
-                    *provider = CompletionProvider::Ollama(OllamaCompletionProvider::new(
-                        model.clone(),
-                        api_url.clone(),
-                        client.http_client(),
-                        low_speed_timeout_in_seconds.map(Duration::from_secs),
-                        settings_version,
-                        cx,
-                    ));
-                }
-            }
+            provider.update_settings(settings_version, cx);
         })
     })
     .detach();
 }
 
-pub enum CompletionProvider {
-    OpenAi(OpenAiCompletionProvider),
-    Anthropic(AnthropicCompletionProvider),
-    Cloud(CloudCompletionProvider),
-    #[cfg(test)]
-    Fake(FakeCompletionProvider),
-    Ollama(OllamaCompletionProvider),
+pub struct CompletionResponse {
+    pub inner: BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>,
+    _lock: SemaphoreGuardArc,
+}
+
+pub trait LanguageModelCompletionProvider: Send + Sync {
+    fn available_models(&self, cx: &AppContext) -> Vec<LanguageModel>;
+    fn settings_version(&self) -> usize;
+    fn is_authenticated(&self) -> bool;
+    fn authenticate(&self, cx: &AppContext) -> Task<Result<()>>;
+    fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView;
+    fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>>;
+    fn model(&self) -> LanguageModel;
+    fn count_tokens(
+        &self,
+        request: LanguageModelRequest,
+        cx: &AppContext,
+    ) -> BoxFuture<'static, Result<usize>>;
+    fn complete(
+        &self,
+        request: LanguageModelRequest,
+    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
+
+    fn as_any_mut(&mut self) -> &mut dyn Any;
+}
+
+const MAX_CONCURRENT_COMPLETION_REQUESTS: usize = 4;
+
+pub struct CompletionProvider {
+    provider: Arc<RwLock<dyn LanguageModelCompletionProvider>>,
+    client: Option<Arc<Client>>,
+    request_limiter: Arc<Semaphore>,
+}
+
+impl CompletionProvider {
+    pub fn new(
+        provider: Arc<RwLock<dyn LanguageModelCompletionProvider>>,
+        client: Option<Arc<Client>>,
+    ) -> Self {
+        Self {
+            provider,
+            client,
+            request_limiter: Arc::new(Semaphore::new(MAX_CONCURRENT_COMPLETION_REQUESTS)),
+        }
+    }
+
+    pub fn available_models(&self, cx: &AppContext) -> Vec<LanguageModel> {
+        self.provider.read().available_models(cx)
+    }
+
+    pub fn settings_version(&self) -> usize {
+        self.provider.read().settings_version()
+    }
+
+    pub fn is_authenticated(&self) -> bool {
+        self.provider.read().is_authenticated()
+    }
+
+    pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
+        self.provider.read().authenticate(cx)
+    }
+
+    pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
+        self.provider.read().authentication_prompt(cx)
+    }
+
+    pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
+        self.provider.read().reset_credentials(cx)
+    }
+
+    pub fn model(&self) -> LanguageModel {
+        self.provider.read().model()
+    }
+
+    pub fn count_tokens(
+        &self,
+        request: LanguageModelRequest,
+        cx: &AppContext,
+    ) -> BoxFuture<'static, Result<usize>> {
+        self.provider.read().count_tokens(request, cx)
+    }
+
+    pub fn complete(
+        &self,
+        request: LanguageModelRequest,
+        cx: &AppContext,
+    ) -> Task<CompletionResponse> {
+        let rate_limiter = self.request_limiter.clone();
+        let provider = self.provider.clone();
+        cx.background_executor().spawn(async move {
+            let lock = rate_limiter.acquire_arc().await;
+            let response = provider.read().complete(request);
+            CompletionResponse {
+                inner: response,
+                _lock: lock,
+            }
+        })
+    }
 }
 
 impl gpui::Global for CompletionProvider {}
@@ -218,121 +161,213 @@ impl CompletionProvider {
         cx.global::<Self>()
     }
 
-    pub fn available_models(&self, cx: &AppContext) -> Vec<LanguageModel> {
-        match self {
-            CompletionProvider::OpenAi(provider) => provider
-                .available_models(cx)
-                .map(LanguageModel::OpenAi)
-                .collect(),
-            CompletionProvider::Anthropic(provider) => provider
-                .available_models()
-                .map(LanguageModel::Anthropic)
-                .collect(),
-            CompletionProvider::Cloud(provider) => provider
-                .available_models()
-                .map(LanguageModel::Cloud)
-                .collect(),
-            CompletionProvider::Ollama(provider) => provider
-                .available_models()
-                .map(|model| LanguageModel::Ollama(model.clone()))
-                .collect(),
-            #[cfg(test)]
-            CompletionProvider::Fake(_) => unimplemented!(),
-        }
-    }
-
-    pub fn settings_version(&self) -> usize {
-        match self {
-            CompletionProvider::OpenAi(provider) => provider.settings_version(),
-            CompletionProvider::Anthropic(provider) => provider.settings_version(),
-            CompletionProvider::Cloud(provider) => provider.settings_version(),
-            CompletionProvider::Ollama(provider) => provider.settings_version(),
-            #[cfg(test)]
-            CompletionProvider::Fake(_) => unimplemented!(),
-        }
-    }
-
-    pub fn is_authenticated(&self) -> bool {
-        match self {
-            CompletionProvider::OpenAi(provider) => provider.is_authenticated(),
-            CompletionProvider::Anthropic(provider) => provider.is_authenticated(),
-            CompletionProvider::Cloud(provider) => provider.is_authenticated(),
-            CompletionProvider::Ollama(provider) => provider.is_authenticated(),
-            #[cfg(test)]
-            CompletionProvider::Fake(_) => true,
-        }
-    }
-
-    pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
-        match self {
-            CompletionProvider::OpenAi(provider) => provider.authenticate(cx),
-            CompletionProvider::Anthropic(provider) => provider.authenticate(cx),
-            CompletionProvider::Cloud(provider) => provider.authenticate(cx),
-            CompletionProvider::Ollama(provider) => provider.authenticate(cx),
-            #[cfg(test)]
-            CompletionProvider::Fake(_) => Task::ready(Ok(())),
-        }
-    }
-
-    pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
-        match self {
-            CompletionProvider::OpenAi(provider) => provider.authentication_prompt(cx),
-            CompletionProvider::Anthropic(provider) => provider.authentication_prompt(cx),
-            CompletionProvider::Cloud(provider) => provider.authentication_prompt(cx),
-            CompletionProvider::Ollama(provider) => provider.authentication_prompt(cx),
-            #[cfg(test)]
-            CompletionProvider::Fake(_) => unimplemented!(),
-        }
-    }
-
-    pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
-        match self {
-            CompletionProvider::OpenAi(provider) => provider.reset_credentials(cx),
-            CompletionProvider::Anthropic(provider) => provider.reset_credentials(cx),
-            CompletionProvider::Cloud(_) => Task::ready(Ok(())),
-            CompletionProvider::Ollama(provider) => provider.reset_credentials(cx),
-            #[cfg(test)]
-            CompletionProvider::Fake(_) => Task::ready(Ok(())),
-        }
-    }
-
-    pub fn model(&self) -> LanguageModel {
-        match self {
-            CompletionProvider::OpenAi(provider) => LanguageModel::OpenAi(provider.model()),
-            CompletionProvider::Anthropic(provider) => LanguageModel::Anthropic(provider.model()),
-            CompletionProvider::Cloud(provider) => LanguageModel::Cloud(provider.model()),
-            CompletionProvider::Ollama(provider) => LanguageModel::Ollama(provider.model()),
-            #[cfg(test)]
-            CompletionProvider::Fake(_) => LanguageModel::default(),
-        }
-    }
-
-    pub fn count_tokens(
-        &self,
-        request: LanguageModelRequest,
-        cx: &AppContext,
-    ) -> BoxFuture<'static, Result<usize>> {
-        match self {
-            CompletionProvider::OpenAi(provider) => provider.count_tokens(request, cx),
-            CompletionProvider::Anthropic(provider) => provider.count_tokens(request, cx),
-            CompletionProvider::Cloud(provider) => provider.count_tokens(request, cx),
-            CompletionProvider::Ollama(provider) => provider.count_tokens(request, cx),
-            #[cfg(test)]
-            CompletionProvider::Fake(_) => futures::FutureExt::boxed(futures::future::ready(Ok(0))),
-        }
-    }
-
-    pub fn complete(
-        &self,
-        request: LanguageModelRequest,
-    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
-        match self {
-            CompletionProvider::OpenAi(provider) => provider.complete(request),
-            CompletionProvider::Anthropic(provider) => provider.complete(request),
-            CompletionProvider::Cloud(provider) => provider.complete(request),
-            CompletionProvider::Ollama(provider) => provider.complete(request),
-            #[cfg(test)]
-            CompletionProvider::Fake(provider) => provider.complete(),
-        }
-    }
-}
+    pub fn update_current_as<R, T: LanguageModelCompletionProvider + 'static>(
+        &mut self,
+        update: impl FnOnce(&mut T) -> R,
+    ) -> Option<R> {
+        let mut provider = self.provider.write();
+        if let Some(provider) = provider.as_any_mut().downcast_mut::<T>() {
+            Some(update(provider))
+        } else {
+            None
+        }
+    }
+
+    pub fn update_settings(&mut self, version: usize, cx: &mut AppContext) {
+        let updated = match &AssistantSettings::get_global(cx).provider {
+            AssistantProvider::ZedDotDev { model } => self
+                .update_current_as::<_, CloudCompletionProvider>(|provider| {
+                    provider.update(model.clone(), version);
+                }),
+            AssistantProvider::OpenAi {
+                model,
+                api_url,
+                low_speed_timeout_in_seconds,
+                available_models,
+            } => self.update_current_as::<_, OpenAiCompletionProvider>(|provider| {
+                provider.update(
+                    choose_openai_model(&model, &available_models),
+                    api_url.clone(),
+                    low_speed_timeout_in_seconds.map(Duration::from_secs),
+                    version,
+                );
+            }),
+            AssistantProvider::Anthropic {
+                model,
+                api_url,
+                low_speed_timeout_in_seconds,
+            } => self.update_current_as::<_, AnthropicCompletionProvider>(|provider| {
+                provider.update(
+                    model.clone(),
+                    api_url.clone(),
+                    low_speed_timeout_in_seconds.map(Duration::from_secs),
+                    version,
+                );
+            }),
+            AssistantProvider::Ollama {
+                model,
+                api_url,
+                low_speed_timeout_in_seconds,
+            } => self.update_current_as::<_, OllamaCompletionProvider>(|provider| {
+                provider.update(
+                    model.clone(),
+                    api_url.clone(),
+                    low_speed_timeout_in_seconds.map(Duration::from_secs),
+                    version,
+                    cx,
+                );
+            }),
+        };
+
+        // Previously configured provider was changed to another one
+        if updated.is_none() {
+            if let Some(client) = self.client.clone() {
+                self.provider = create_provider_from_settings(client, version, cx);
+            } else {
+                log::warn!("completion provider cannot be created because client is not set");
+            }
+        }
+    }
+}
+
+fn create_provider_from_settings(
+    client: Arc<Client>,
+    settings_version: usize,
+    cx: &mut AppContext,
+) -> Arc<RwLock<dyn LanguageModelCompletionProvider>> {
+    match &AssistantSettings::get_global(cx).provider {
+        AssistantProvider::ZedDotDev { model } => Arc::new(RwLock::new(
+            CloudCompletionProvider::new(model.clone(), client.clone(), settings_version, cx),
+        )),
+        AssistantProvider::OpenAi {
+            model,
+            api_url,
+            low_speed_timeout_in_seconds,
+            available_models,
+        } => Arc::new(RwLock::new(OpenAiCompletionProvider::new(
+            choose_openai_model(&model, &available_models),
+            api_url.clone(),
+            client.http_client(),
+            low_speed_timeout_in_seconds.map(Duration::from_secs),
+            settings_version,
+        ))),
+        AssistantProvider::Anthropic {
+            model,
+            api_url,
+            low_speed_timeout_in_seconds,
+        } => Arc::new(RwLock::new(AnthropicCompletionProvider::new(
+            model.clone(),
+            api_url.clone(),
+            client.http_client(),
+            low_speed_timeout_in_seconds.map(Duration::from_secs),
+            settings_version,
+        ))),
+        AssistantProvider::Ollama {
+            model,
+            api_url,
+            low_speed_timeout_in_seconds,
+        } => Arc::new(RwLock::new(OllamaCompletionProvider::new(
+            model.clone(),
+            api_url.clone(),
+            client.http_client(),
+            low_speed_timeout_in_seconds.map(Duration::from_secs),
+            settings_version,
+            cx,
+        ))),
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use std::sync::Arc;
+
+    use gpui::AppContext;
+    use parking_lot::RwLock;
+    use settings::SettingsStore;
+    use smol::stream::StreamExt;
+
+    use crate::{
+        completion_provider::MAX_CONCURRENT_COMPLETION_REQUESTS, CompletionProvider,
+        FakeCompletionProvider, LanguageModelRequest,
+    };
+
+    #[gpui::test]
+    fn test_rate_limiting(cx: &mut AppContext) {
+        SettingsStore::test(cx);
+        let fake_provider = FakeCompletionProvider::setup_test(cx);
+
+        let provider = CompletionProvider::new(Arc::new(RwLock::new(fake_provider.clone())), None);
+
+        // Enqueue some requests
+        for i in 0..MAX_CONCURRENT_COMPLETION_REQUESTS * 2 {
+            let response = provider.complete(
+                LanguageModelRequest {
+                    temperature: i as f32 / 10.0,
+                    ..Default::default()
+                },
+                cx,
+            );
+            cx.background_executor()
+                .spawn(async move {
+                    let response = response.await;
+                    let mut stream = response.inner.await.unwrap();
+                    while let Some(message) = stream.next().await {
+                        message.unwrap();
+                    }
+                })
+                .detach();
+        }
+        cx.background_executor().run_until_parked();
+
+        assert_eq!(
+            fake_provider.completion_count(),
+            MAX_CONCURRENT_COMPLETION_REQUESTS
+        );
+
+        // Get the first completion request that is in flight and mark it as completed.
+        let completion = fake_provider
+            .running_completions()
+            .into_iter()
+            .next()
+            .unwrap();
+        fake_provider.finish_completion(&completion);
+
+        // Ensure that the number of in-flight completion requests is reduced.
+        assert_eq!(
+            fake_provider.completion_count(),
+            MAX_CONCURRENT_COMPLETION_REQUESTS - 1
+        );
+
+        cx.background_executor().run_until_parked();
+
+        // Ensure that another completion request was allowed to acquire the lock.
+        assert_eq!(
+            fake_provider.completion_count(),
+            MAX_CONCURRENT_COMPLETION_REQUESTS
+        );
+
+        // Mark all completion requests as finished that are in flight.
+        for request in fake_provider.running_completions() {
+            fake_provider.finish_completion(&request);
+        }
+
+        assert_eq!(fake_provider.completion_count(), 0);
+
+        // Wait until the background tasks acquire the lock again.
+        cx.background_executor().run_until_parked();
+
+        assert_eq!(
+            fake_provider.completion_count(),
+            MAX_CONCURRENT_COMPLETION_REQUESTS - 1
+        );
+
+        // Finish all remaining completion requests.
+        for request in fake_provider.running_completions() {
+            fake_provider.finish_completion(&request);
+        }
+
+        cx.background_executor().run_until_parked();
+        assert_eq!(fake_provider.completion_count(), 0);
+    }
+}

View file

@@ -2,7 +2,7 @@ use crate::{
     assistant_settings::AnthropicModel, CompletionProvider, LanguageModel, LanguageModelRequest,
     Role,
 };
-use crate::{count_open_ai_tokens, LanguageModelRequestMessage};
+use crate::{count_open_ai_tokens, LanguageModelCompletionProvider, LanguageModelRequestMessage};
 use anthropic::{stream_completion, Request, RequestMessage};
 use anyhow::{anyhow, Result};
 use editor::{Editor, EditorElement, EditorStyle};
@@ -26,50 +26,22 @@ pub struct AnthropicCompletionProvider {
     settings_version: usize,
 }
 
-impl AnthropicCompletionProvider {
-    pub fn new(
-        model: AnthropicModel,
-        api_url: String,
-        http_client: Arc<dyn HttpClient>,
-        low_speed_timeout: Option<Duration>,
-        settings_version: usize,
-    ) -> Self {
-        Self {
-            api_key: None,
-            api_url,
-            model,
-            http_client,
-            low_speed_timeout,
-            settings_version,
-        }
-    }
-
-    pub fn update(
-        &mut self,
-        model: AnthropicModel,
-        api_url: String,
-        low_speed_timeout: Option<Duration>,
-        settings_version: usize,
-    ) {
-        self.model = model;
-        self.api_url = api_url;
-        self.low_speed_timeout = low_speed_timeout;
-        self.settings_version = settings_version;
-    }
-
-    pub fn available_models(&self) -> impl Iterator<Item = AnthropicModel> {
+impl LanguageModelCompletionProvider for AnthropicCompletionProvider {
+    fn available_models(&self, _cx: &AppContext) -> Vec<LanguageModel> {
         AnthropicModel::iter()
+            .map(LanguageModel::Anthropic)
+            .collect()
     }
 
-    pub fn settings_version(&self) -> usize {
+    fn settings_version(&self) -> usize {
         self.settings_version
     }
 
-    pub fn is_authenticated(&self) -> bool {
+    fn is_authenticated(&self) -> bool {
         self.api_key.is_some()
     }
 
-    pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
+    fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
         if self.is_authenticated() {
             Task::ready(Ok(()))
         } else {
@@ -85,36 +57,36 @@ impl AnthropicCompletionProvider {
                     String::from_utf8(api_key)?
                 };
                 cx.update_global::<CompletionProvider, _>(|provider, _cx| {
-                    if let CompletionProvider::Anthropic(provider) = provider {
+                    provider.update_current_as::<_, AnthropicCompletionProvider>(|provider| {
                         provider.api_key = Some(api_key);
-                    }
+                    });
                 })
             })
         }
    }
 
-    pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
+    fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
         let delete_credentials = cx.delete_credentials(&self.api_url);
         cx.spawn(|mut cx| async move {
             delete_credentials.await.log_err();
             cx.update_global::<CompletionProvider, _>(|provider, _cx| {
-                if let CompletionProvider::Anthropic(provider) = provider {
+                provider.update_current_as::<_, AnthropicCompletionProvider>(|provider| {
                     provider.api_key = None;
-                }
+                });
             })
         })
    }
 
-    pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
+    fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
         cx.new_view(|cx| AuthenticationPrompt::new(self.api_url.clone(), cx))
             .into()
     }
 
-    pub fn model(&self) -> AnthropicModel {
-        self.model.clone()
+    fn model(&self) -> LanguageModel {
+        LanguageModel::Anthropic(self.model.clone())
     }
 
-    pub fn count_tokens(
+    fn count_tokens(
         &self,
         request: LanguageModelRequest,
         cx: &AppContext,
@@ -122,7 +94,7 @@ impl AnthropicCompletionProvider {
         count_open_ai_tokens(request, cx.background_executor())
     }
 
-    pub fn complete(
+    fn complete(
         &self,
         request: LanguageModelRequest,
     ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
@@ -167,12 +139,48 @@ impl AnthropicCompletionProvider {
             .boxed()
     }
 
+    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
+        self
+    }
+}
+
+impl AnthropicCompletionProvider {
+    pub fn new(
+        model: AnthropicModel,
+        api_url: String,
+        http_client: Arc<dyn HttpClient>,
+        low_speed_timeout: Option<Duration>,
+        settings_version: usize,
+    ) -> Self {
+        Self {
+            api_key: None,
+            api_url,
+            model,
+            http_client,
+            low_speed_timeout,
+            settings_version,
+        }
+    }
+
+    pub fn update(
+        &mut self,
+        model: AnthropicModel,
+        api_url: String,
+        low_speed_timeout: Option<Duration>,
+        settings_version: usize,
+    ) {
+        self.model = model;
+        self.api_url = api_url;
+        self.low_speed_timeout = low_speed_timeout;
+        self.settings_version = settings_version;
+    }
+
     fn to_anthropic_request(&self, mut request: LanguageModelRequest) -> Request {
         preprocess_anthropic_request(&mut request);
         let model = match request.model {
             LanguageModel::Anthropic(model) => model,
-            _ => self.model(),
+            _ => self.model.clone(),
         };
 
         let mut system_message = String::new();
@@ -278,9 +286,9 @@ impl AuthenticationPrompt {
         cx.spawn(|_, mut cx| async move {
             write_credentials.await?;
             cx.update_global::<CompletionProvider, _>(|provider, _cx| {
-                if let CompletionProvider::Anthropic(provider) = provider {
+                provider.update_current_as::<_, AnthropicCompletionProvider>(|provider| {
                     provider.api_key = Some(api_key);
-                }
+                });
             })
         })
         .detach_and_log_err(cx);

View file

@@ -1,6 +1,6 @@
 use crate::{
     assistant_settings::CloudModel, count_open_ai_tokens, CompletionProvider, LanguageModel,
-    LanguageModelRequest,
+    LanguageModelCompletionProvider, LanguageModelRequest,
 };
 use anyhow::{anyhow, Result};
 use client::{proto, Client};
@@ -30,11 +30,9 @@ impl CloudCompletionProvider {
         let maintain_client_status = cx.spawn(|mut cx| async move {
             while let Some(status) = status_rx.next().await {
                 let _ = cx.update_global::<CompletionProvider, _>(|provider, _cx| {
-                    if let CompletionProvider::Cloud(provider) = provider {
+                    provider.update_current_as::<_, Self>(|provider| {
                         provider.status = status;
-                    } else {
-                        unreachable!()
-                    }
+                    });
                 });
             }
         });
@@ -51,44 +49,53 @@ impl CloudCompletionProvider {
         self.model = model;
         self.settings_version = settings_version;
     }
+}
 
-    pub fn available_models(&self) -> impl Iterator<Item = CloudModel> {
+impl LanguageModelCompletionProvider for CloudCompletionProvider {
+    fn available_models(&self, _cx: &AppContext) -> Vec<LanguageModel> {
         let mut custom_model = if let CloudModel::Custom(custom_model) = self.model.clone() {
             Some(custom_model)
         } else {
             None
         };
-        CloudModel::iter().filter_map(move |model| {
-            if let CloudModel::Custom(_) = model {
-                Some(CloudModel::Custom(custom_model.take()?))
-            } else {
-                Some(model)
-            }
-        })
+        CloudModel::iter()
+            .filter_map(move |model| {
+                if let CloudModel::Custom(_) = model {
+                    Some(CloudModel::Custom(custom_model.take()?))
+                } else {
+                    Some(model)
+                }
+            })
+            .map(LanguageModel::Cloud)
+            .collect()
     }
 
-    pub fn settings_version(&self) -> usize {
+    fn settings_version(&self) -> usize {
         self.settings_version
     }
 
-    pub fn model(&self) -> CloudModel {
-        self.model.clone()
-    }
-
-    pub fn is_authenticated(&self) -> bool {
+    fn is_authenticated(&self) -> bool {
         self.status.is_connected()
     }
 
-    pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
+    fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
         let client = self.client.clone();
         cx.spawn(move |cx| async move { client.authenticate_and_connect(true, &cx).await })
     }
 
-    pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
+    fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
         cx.new_view(|_cx| AuthenticationPrompt).into()
     }
 
-    pub fn count_tokens(
+    fn reset_credentials(&self, _cx: &AppContext) -> Task<Result<()>> {
+        Task::ready(Ok(()))
+    }
+
+    fn model(&self) -> LanguageModel {
+        LanguageModel::Cloud(self.model.clone())
+    }
+
+    fn count_tokens(
         &self,
         request: LanguageModelRequest,
         cx: &AppContext,
@@ -128,7 +135,7 @@ impl CloudCompletionProvider {
         }
     }
 
-    pub fn complete(
+    fn complete(
         &self,
         mut request: LanguageModelRequest,
     ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
@@ -161,6 +168,10 @@ impl CloudCompletionProvider {
             })
             .boxed()
     }
+
+    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
+        self
+    }
 }
 
 struct AuthenticationPrompt;

View file

@@ -1,29 +1,107 @@
 use anyhow::Result;
+use collections::HashMap;
 use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
+use gpui::{AnyView, AppContext, Task};
 use std::sync::Arc;
+use ui::WindowContext;
+
+use crate::{LanguageModel, LanguageModelCompletionProvider, LanguageModelRequest};
 
 #[derive(Clone, Default)]
 pub struct FakeCompletionProvider {
-    current_completion_tx: Arc<parking_lot::Mutex<Option<mpsc::UnboundedSender<String>>>>,
+    current_completion_txs: Arc<parking_lot::Mutex<HashMap<String, mpsc::UnboundedSender<String>>>>,
 }
 
 impl FakeCompletionProvider {
-    pub fn complete(&self) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
-        let (tx, rx) = mpsc::unbounded();
-        *self.current_completion_tx.lock() = Some(tx);
-        async move { Ok(rx.map(Ok).boxed()) }.boxed()
+    #[cfg(test)]
+    pub fn setup_test(cx: &mut AppContext) -> Self {
+        use crate::CompletionProvider;
+        use parking_lot::RwLock;
+
+        let this = Self::default();
+        let provider = CompletionProvider::new(Arc::new(RwLock::new(this.clone())), None);
+        cx.set_global(provider);
+        this
     }
 
-    pub fn send_completion(&self, chunk: String) {
-        self.current_completion_tx
+    pub fn running_completions(&self) -> Vec<LanguageModelRequest> {
+        self.current_completion_txs
             .lock()
-            .as_ref()
+            .keys()
+            .map(|k| serde_json::from_str(k).unwrap())
+            .collect()
+    }
+
+    pub fn completion_count(&self) -> usize {
+        self.current_completion_txs.lock().len()
+    }
+
+    pub fn send_completion(&self, request: &LanguageModelRequest, chunk: String) {
+        let json = serde_json::to_string(request).unwrap();
+        self.current_completion_txs
+            .lock()
+            .get(&json)
             .unwrap()
             .unbounded_send(chunk)
             .unwrap();
     }
 
-    pub fn finish_completion(&self) {
-        self.current_completion_tx.lock().take();
+    pub fn finish_completion(&self, request: &LanguageModelRequest) {
+        self.current_completion_txs
+            .lock()
+            .remove(&serde_json::to_string(request).unwrap());
+    }
+}
+
+impl LanguageModelCompletionProvider for FakeCompletionProvider {
+    fn available_models(&self, _cx: &AppContext) -> Vec<LanguageModel> {
+        vec![LanguageModel::default()]
+    }
+
+    fn settings_version(&self) -> usize {
+        0
+    }
+
+    fn is_authenticated(&self) -> bool {
+        true
+    }
+
+    fn authenticate(&self, _cx: &AppContext) -> Task<Result<()>> {
+        Task::ready(Ok(()))
+    }
+
+    fn authentication_prompt(&self, _cx: &mut WindowContext) -> AnyView {
+        unimplemented!()
+    }
+
+    fn reset_credentials(&self, _cx: &AppContext) -> Task<Result<()>> {
+        Task::ready(Ok(()))
+    }
+
+    fn model(&self) -> LanguageModel {
+        LanguageModel::default()
+    }
+
+    fn count_tokens(
+        &self,
+        _request: LanguageModelRequest,
+        _cx: &AppContext,
+    ) -> BoxFuture<'static, Result<usize>> {
+        futures::future::ready(Ok(0)).boxed()
+    }
+
+    fn complete(
+        &self,
+        _request: LanguageModelRequest,
+    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
+        let (tx, rx) = mpsc::unbounded();
+        self.current_completion_txs
+            .lock()
+            .insert(serde_json::to_string(&_request).unwrap(), tx);
+        async move { Ok(rx.map(Ok).boxed()) }.boxed()
+    }
+
+    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
+        self
     }
 }

View file

@@ -1,3 +1,4 @@
+use crate::LanguageModelCompletionProvider;
 use crate::{
     assistant_settings::OllamaModel, CompletionProvider, LanguageModel, LanguageModelRequest, Role,
 };
@@ -26,6 +27,108 @@ pub struct OllamaCompletionProvider {
     available_models: Vec<OllamaModel>,
 }
 
+impl LanguageModelCompletionProvider for OllamaCompletionProvider {
+    fn available_models(&self, _cx: &AppContext) -> Vec<LanguageModel> {
+        self.available_models
+            .iter()
+            .map(|m| LanguageModel::Ollama(m.clone()))
+            .collect()
+    }
+
+    fn settings_version(&self) -> usize {
+        self.settings_version
+    }
+
+    fn is_authenticated(&self) -> bool {
+        !self.available_models.is_empty()
+    }
+
+    fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
+        if self.is_authenticated() {
+            Task::ready(Ok(()))
+        } else {
+            self.fetch_models(cx)
+        }
+    }
+
+    fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
+        let fetch_models = Box::new(move |cx: &mut WindowContext| {
+            cx.update_global::<CompletionProvider, _>(|provider, cx| {
+                provider
+                    .update_current_as::<_, OllamaCompletionProvider>(|provider| {
+                        provider.fetch_models(cx)
+                    })
+                    .unwrap_or_else(|| Task::ready(Ok(())))
+            })
+        });
+        cx.new_view(|cx| DownloadOllamaMessage::new(fetch_models, cx))
+            .into()
+    }
+
+    fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
+        self.fetch_models(cx)
+    }
+
+    fn model(&self) -> LanguageModel {
+        LanguageModel::Ollama(self.model.clone())
+    }
+
+    fn count_tokens(
+        &self,
+        request: LanguageModelRequest,
+        _cx: &AppContext,
+    ) -> BoxFuture<'static, Result<usize>> {
+        // There is no endpoint for this _yet_ in Ollama
+        // see: https://github.com/ollama/ollama/issues/1716 and https://github.com/ollama/ollama/issues/3582
+        let token_count = request
+            .messages
+            .iter()
+            .map(|msg| msg.content.chars().count())
+            .sum::<usize>()
+            / 4;
+        async move { Ok(token_count) }.boxed()
+    }
+
+    fn complete(
+        &self,
+        request: LanguageModelRequest,
+    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
+        let request = self.to_ollama_request(request);
+
+        let http_client = self.http_client.clone();
+        let api_url = self.api_url.clone();
+        let low_speed_timeout = self.low_speed_timeout;
+        async move {
+            let request =
+                stream_chat_completion(http_client.as_ref(), &api_url, request, low_speed_timeout);
+            let response = request.await?;
+            let stream = response
+                .filter_map(|response| async move {
+                    match response {
+                        Ok(delta) => {
+                            let content = match delta.message {
+                                ChatMessage::User { content } => content,
+                                ChatMessage::Assistant { content } => content,
+                                ChatMessage::System { content } => content,
+                            };
+                            Some(Ok(content))
+                        }
+                        Err(error) => Some(Err(error)),
+                    }
+                })
+                .boxed();
+            Ok(stream)
+        }
+        .boxed()
+    }
+
+    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
+        self
+    }
+}
+
 impl OllamaCompletionProvider {
     pub fn new(
         model: OllamaModel,
@@ -87,36 +190,12 @@ impl OllamaCompletionProvider {
         self.settings_version = settings_version;
     }
 
-    pub fn available_models(&self) -> impl Iterator<Item = &OllamaModel> {
-        self.available_models.iter()
-    }
-
     pub fn select_first_available_model(&mut self) {
         if let Some(model) = self.available_models.first() {
             self.model = model.clone();
         }
     }
 
-    pub fn settings_version(&self) -> usize {
-        self.settings_version
-    }
-
-    pub fn is_authenticated(&self) -> bool {
-        !self.available_models.is_empty()
-    }
-
-    pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
-        if self.is_authenticated() {
-            Task::ready(Ok(()))
-        } else {
-            self.fetch_models(cx)
-        }
-    }
-
-    pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
-        self.fetch_models(cx)
-    }
-
     pub fn fetch_models(&self, cx: &AppContext) -> Task<Result<()>> {
         let http_client = self.http_client.clone();
         let api_url = self.api_url.clone();
@@ -137,90 +216,21 @@ impl OllamaCompletionProvider {
             models.sort_by(|a, b| a.name.cmp(&b.name));
 
             cx.update_global::<CompletionProvider, _>(|provider, _cx| {
-                if let CompletionProvider::Ollama(provider) = provider {
+                provider.update_current_as::<_, OllamaCompletionProvider>(|provider| {
                     provider.available_models = models;
 
                     if !provider.available_models.is_empty() && provider.model.name.is_empty() {
                         provider.select_first_available_model()
                     }
-                }
+                });
             })
         })
     }
 
-    pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
-        let fetch_models = Box::new(move |cx: &mut WindowContext| {
-            cx.update_global::<CompletionProvider, _>(|provider, cx| {
-                if let CompletionProvider::Ollama(provider) = provider {
-                    provider.fetch_models(cx)
-                } else {
-                    Task::ready(Ok(()))
-                }
-            })
-        });
-        cx.new_view(|cx| DownloadOllamaMessage::new(fetch_models, cx))
-            .into()
-    }
-
-    pub fn model(&self) -> OllamaModel {
-        self.model.clone()
-    }
-
-    pub fn count_tokens(
-        &self,
-        request: LanguageModelRequest,
-        _cx: &AppContext,
-    ) -> BoxFuture<'static, Result<usize>> {
-        // There is no endpoint for this _yet_ in Ollama
-        // see: https://github.com/ollama/ollama/issues/1716 and https://github.com/ollama/ollama/issues/3582
-        let token_count = request
-            .messages
-            .iter()
-            .map(|msg| msg.content.chars().count())
-            .sum::<usize>()
-            / 4;
-        async move { Ok(token_count) }.boxed()
-    }
-
-    pub fn complete(
-        &self,
-        request: LanguageModelRequest,
-    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
-        let request = self.to_ollama_request(request);
-
-        let http_client = self.http_client.clone();
-        let api_url = self.api_url.clone();
-        let low_speed_timeout = self.low_speed_timeout;
-        async move {
-            let request =
-                stream_chat_completion(http_client.as_ref(), &api_url, request, low_speed_timeout);
-            let response = request.await?;
-            let stream = response
-                .filter_map(|response| async move {
-                    match response {
-                        Ok(delta) => {
-                            let content = match delta.message {
-                                ChatMessage::User { content } => content,
-                                ChatMessage::Assistant { content } => content,
-                                ChatMessage::System { content } => content,
-                            };
-                            Some(Ok(content))
-                        }
-                        Err(error) => Some(Err(error)),
-                    }
-                })
-                .boxed();
-            Ok(stream)
-        }
-        .boxed()
-    }
-
     fn to_ollama_request(&self, request: LanguageModelRequest) -> ChatRequest {
         let model = match request.model {
             LanguageModel::Ollama(model) => model,
-            _ => self.model(),
+            _ => self.model.clone(),
         };
 
         ChatRequest {

View file

@@ -1,5 +1,6 @@
 use crate::assistant_settings::CloudModel;
 use crate::assistant_settings::{AssistantProvider, AssistantSettings};
+use crate::LanguageModelCompletionProvider;
 use crate::{
     assistant_settings::OpenAiModel, CompletionProvider, LanguageModel, LanguageModelRequest, Role,
 };
@@ -57,37 +58,75 @@ impl OpenAiCompletionProvider {
         self.settings_version = settings_version;
     }
 
-    pub fn available_models(&self, cx: &AppContext) -> impl Iterator<Item = OpenAiModel> {
+    fn to_open_ai_request(&self, request: LanguageModelRequest) -> Request {
+        let model = match request.model {
+            LanguageModel::OpenAi(model) => model,
+            _ => self.model.clone(),
+        };
+
+        Request {
+            model,
+            messages: request
+                .messages
+                .into_iter()
+                .map(|msg| match msg.role {
+                    Role::User => RequestMessage::User {
+                        content: msg.content,
+                    },
+                    Role::Assistant => RequestMessage::Assistant {
+                        content: Some(msg.content),
+                        tool_calls: Vec::new(),
+                    },
+                    Role::System => RequestMessage::System {
+                        content: msg.content,
+                    },
+                })
+                .collect(),
+            stream: true,
+            stop: request.stop,
+            temperature: request.temperature,
+            tools: Vec::new(),
+            tool_choice: None,
+        }
+    }
+}
+
+impl LanguageModelCompletionProvider for OpenAiCompletionProvider {
+    fn available_models(&self, cx: &AppContext) -> Vec<LanguageModel> {
         if let AssistantProvider::OpenAi {
             available_models, ..
         } = &AssistantSettings::get_global(cx).provider
         {
             if !available_models.is_empty() {
-                // available_models is set, just return it
-                return available_models.clone().into_iter();
+                return available_models
+                    .iter()
+                    .cloned()
+                    .map(LanguageModel::OpenAi)
+                    .collect();
             }
         }
         let available_models = if matches!(self.model, OpenAiModel::Custom { .. }) {
-            // available_models is not set but the default model is set to custom, only show custom
             vec![self.model.clone()]
         } else {
-            // default case, use all models except custom
             OpenAiModel::iter()
                 .filter(|model| !matches!(model, OpenAiModel::Custom { .. }))
                 .collect()
         };
-        available_models.into_iter()
+        available_models
+            .into_iter()
+            .map(LanguageModel::OpenAi)
+            .collect()
     }
 
-    pub fn settings_version(&self) -> usize {
+    fn settings_version(&self) -> usize {
         self.settings_version
     }
 
-    pub fn is_authenticated(&self) -> bool {
+    fn is_authenticated(&self) -> bool {
         self.api_key.is_some()
     }
 
-    pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
+    fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
         if self.is_authenticated() {
             Task::ready(Ok(()))
         } else {
@@ -103,36 +142,36 @@ impl OpenAiCompletionProvider {
                     String::from_utf8(api_key)?
                 };
                 cx.update_global::<CompletionProvider, _>(|provider, _cx| {
-                    if let CompletionProvider::OpenAi(provider) = provider {
+                    provider.update_current_as::<_, Self>(|provider| {
                         provider.api_key = Some(api_key);
-                    }
+                    });
                 })
             })
         }
    }
 
-    pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
+    fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
         let delete_credentials = cx.delete_credentials(&self.api_url);
         cx.spawn(|mut cx| async move {
             delete_credentials.await.log_err();
             cx.update_global::<CompletionProvider, _>(|provider, _cx| {
-                if let CompletionProvider::OpenAi(provider) = provider {
+                provider.update_current_as::<_, Self>(|provider| {
                     provider.api_key = None;
-                }
+                });
             })
         })
    }
 
-    pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
+    fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
         cx.new_view(|cx| AuthenticationPrompt::new(self.api_url.clone(), cx))
             .into()
     }
 
-    pub fn model(&self) -> OpenAiModel {
-        self.model.clone()
+    fn model(&self) -> LanguageModel {
+        LanguageModel::OpenAi(self.model.clone())
     }
 
-    pub fn count_tokens(
+    fn count_tokens(
         &self,
         request: LanguageModelRequest,
         cx: &AppContext,
@@ -140,7 +179,7 @@ impl OpenAiCompletionProvider {
         count_open_ai_tokens(request, cx.background_executor())
     }
 
-    pub fn complete(
+    fn complete(
         &self,
         request: LanguageModelRequest,
     ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
@@ -173,36 +212,8 @@ impl OpenAiCompletionProvider {
             .boxed()
     }
 
-    fn to_open_ai_request(&self, request: LanguageModelRequest) -> Request {
-        let model = match request.model {
-            LanguageModel::OpenAi(model) => model,
-            _ => self.model(),
-        };
-
-        Request {
-            model,
-            messages: request
-                .messages
-                .into_iter()
-                .map(|msg| match msg.role {
-                    Role::User => RequestMessage::User {
-                        content: msg.content,
-                    },
-                    Role::Assistant => RequestMessage::Assistant {
-                        content: Some(msg.content),
-                        tool_calls: Vec::new(),
-                    },
-                    Role::System => RequestMessage::System {
-                        content: msg.content,
-                    },
-                })
-                .collect(),
-            stream: true,
-            stop: request.stop,
-            temperature: request.temperature,
-            tools: Vec::new(),
-            tool_choice: None,
-        }
+    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
+        self
     }
 }
 
@@ -284,9 +295,9 @@ impl AuthenticationPrompt {
         cx.spawn(|_, mut cx| async move {
             write_credentials.await?;
             cx.update_global::<CompletionProvider, _>(|provider, _cx| {
-                if let CompletionProvider::OpenAi(provider) = provider {
+                provider.update_current_as::<_, OpenAiCompletionProvider>(|provider| {
                     provider.api_key = Some(api_key);
-                }
+                });
             })
         })
         .detach_and_log_err(cx);

View file

@@ -1986,13 +1986,14 @@ impl Codegen {
             .unwrap_or_else(|| snapshot.indent_size_for_line(MultiBufferRow(selection_start.row)));
 
         let model_telemetry_id = prompt.model.telemetry_id();
-        let response = CompletionProvider::global(cx).complete(prompt);
+        let response = CompletionProvider::global(cx).complete(prompt, cx);
         let telemetry = self.telemetry.clone();
         self.edit_position = range.start;
         self.diff = Diff::default();
         self.status = CodegenStatus::Pending;
         self.generation = cx.spawn(|this, mut cx| {
             async move {
+                let response = response.await;
                 let generate = async {
                     let mut edit_start = range.start.to_offset(&snapshot);
@@ -2002,7 +2003,7 @@ impl Codegen {
                     let mut response_latency = None;
                     let request_start = Instant::now();
                     let diff = async {
-                        let chunks = StripInvalidSpans::new(response.await?);
+                        let chunks = StripInvalidSpans::new(response.inner.await?);
                         futures::pin_mut!(chunks);
                         let mut diff = StreamingDiff::new(selected_text.to_string());
@@ -2473,9 +2474,8 @@ mod tests {
     #[gpui::test(iterations = 10)]
     async fn test_transform_autoindent(cx: &mut TestAppContext, mut rng: StdRng) {
-        let provider = FakeCompletionProvider::default();
         cx.set_global(cx.update(SettingsStore::test));
-        cx.set_global(CompletionProvider::Fake(provider.clone()));
+        let provider = cx.update(|cx| FakeCompletionProvider::setup_test(cx));
         cx.update(language_settings::init);
 
         let text = indoc! {"
@@ -2495,8 +2495,11 @@ mod tests {
         });
         let codegen = cx.new_model(|cx| Codegen::new(buffer.clone(), range, None, cx));
 
-        let request = LanguageModelRequest::default();
-        codegen.update(cx, |codegen, cx| codegen.start(request, cx));
+        codegen.update(cx, |codegen, cx| {
+            codegen.start(LanguageModelRequest::default(), cx)
+        });
+        cx.background_executor.run_until_parked();
 
         let mut new_text = concat!(
             " let mut x = 0;\n",
@@ -2508,11 +2511,11 @@ mod tests {
             let max_len = cmp::min(new_text.len(), 10);
             let len = rng.gen_range(1..=max_len);
             let (chunk, suffix) = new_text.split_at(len);
-            provider.send_completion(chunk.into());
+            provider.send_completion(&LanguageModelRequest::default(), chunk.into());
             new_text = suffix;
             cx.background_executor.run_until_parked();
         }
-        provider.finish_completion();
+        provider.finish_completion(&LanguageModelRequest::default());
         cx.background_executor.run_until_parked();
 
         assert_eq!(
@@ -2533,8 +2536,7 @@ mod tests {
         cx: &mut TestAppContext,
         mut rng: StdRng,
     ) {
-        let provider = FakeCompletionProvider::default();
-        cx.set_global(CompletionProvider::Fake(provider.clone()));
+        let provider = cx.update(|cx| FakeCompletionProvider::setup_test(cx));
         cx.set_global(cx.update(SettingsStore::test));
         cx.update(language_settings::init);
 
@@ -2555,6 +2557,8 @@ mod tests {
         let request = LanguageModelRequest::default();
         codegen.update(cx, |codegen, cx| codegen.start(request, cx));
 
+        cx.background_executor.run_until_parked();
+
         let mut new_text = concat!(
             "t mut x = 0;\n",
             "while x < 10 {\n",
@@ -2565,11 +2569,11 @@ mod tests {
             let max_len = cmp::min(new_text.len(), 10);
             let len = rng.gen_range(1..=max_len);
             let (chunk, suffix) = new_text.split_at(len);
-            provider.send_completion(chunk.into());
+            provider.send_completion(&LanguageModelRequest::default(), chunk.into());
             new_text = suffix;
             cx.background_executor.run_until_parked();
         }
-        provider.finish_completion();
+        provider.finish_completion(&LanguageModelRequest::default());
         cx.background_executor.run_until_parked();
 
         assert_eq!(
@@ -2590,8 +2594,7 @@ mod tests {
         cx: &mut TestAppContext,
         mut rng: StdRng,
     ) {
-        let provider = FakeCompletionProvider::default();
-        cx.set_global(CompletionProvider::Fake(provider.clone()));
+        let provider = cx.update(|cx| FakeCompletionProvider::setup_test(cx));
         cx.set_global(cx.update(SettingsStore::test));
         cx.update(language_settings::init);
 
@@ -2612,6 +2615,8 @@ mod tests {
         let request = LanguageModelRequest::default();
         codegen.update(cx, |codegen, cx| codegen.start(request, cx));
 
+        cx.background_executor.run_until_parked();
+
        let mut new_text = concat!(
             "let mut x = 0;\n",
             "while x < 10 {\n",
@@ -2622,11 +2627,11 @@ mod tests {
             let max_len = cmp::min(new_text.len(), 10);
             let len = rng.gen_range(1..=max_len);
             let (chunk, suffix) = new_text.split_at(len);
-            provider.send_completion(chunk.into());
+            provider.send_completion(&LanguageModelRequest::default(), chunk.into());
             new_text = suffix;
             cx.background_executor.run_until_parked();
         }
-        provider.finish_completion();
+        provider.finish_completion(&LanguageModelRequest::default());
         cx.background_executor.run_until_parked();
 
         assert_eq!(

View file

@@ -1026,9 +1026,10 @@ impl Codegen {
         let telemetry = self.telemetry.clone();
         let model_telemetry_id = prompt.model.telemetry_id();
-        let response = CompletionProvider::global(cx).complete(prompt);
+        let response = CompletionProvider::global(cx).complete(prompt, cx);
 
         self.generation = cx.spawn(|this, mut cx| async move {
+            let response = response.await;
             let generate = async {
                 let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1);
@@ -1036,7 +1037,7 @@ impl Codegen {
                 let mut response_latency = None;
                 let request_start = Instant::now();
                 let task = async {
-                    let mut response = response.await?;
+                    let mut response = response.inner.await?;
                     while let Some(chunk) = response.next().await {
                         if response_latency.is_none() {
                             response_latency = Some(request_start.elapsed());