assistant: Limit amount of concurrent completion requests (#13856)
This PR refactors the completion providers so that only a limited number of completion requests are processed at a time. It also starts refactoring the language model providers to use traits, making it easier to support multiple providers in the future.

Release Notes:

- N/A
This commit is contained in:
parent f2711b2fca
commit c4dbe32f20
11 changed files with 693 additions and 532 deletions
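For context on the rate-limiting pattern introduced here: `CompletionProvider::complete` now acquires a semaphore permit before forwarding the request, and the returned `CompletionResponse` keeps the `SemaphoreGuardArc` alive next to the response stream, so the permit is released only when the caller drops the response. Below is a minimal self-contained sketch of that pattern, assuming `smol` and `futures` as in the PR; `fake_stream`-style stand-ins and `main` are illustrative and not code from the change:

```rust
use std::sync::Arc;

use futures::stream::{self, BoxStream, StreamExt};
use smol::lock::{Semaphore, SemaphoreGuardArc};

const MAX_CONCURRENT_COMPLETION_REQUESTS: usize = 4;

// A response that owns a semaphore permit; dropping it frees a slot.
struct CompletionResponse {
    inner: BoxStream<'static, String>,
    _lock: SemaphoreGuardArc,
}

async fn complete(limiter: Arc<Semaphore>, prompt: String) -> CompletionResponse {
    // Suspends while MAX_CONCURRENT_COMPLETION_REQUESTS responses are alive.
    let lock = limiter.acquire_arc().await;
    // Stand-in for a streaming completion from a real provider.
    let inner = stream::iter(vec![prompt, "<done>".to_string()]).boxed();
    CompletionResponse { inner, _lock: lock }
}

fn main() {
    smol::block_on(async {
        let limiter = Arc::new(Semaphore::new(MAX_CONCURRENT_COMPLETION_REQUESTS));
        let mut response = complete(limiter.clone(), "hello".into()).await;
        while let Some(chunk) = response.inner.next().await {
            println!("{chunk}");
        }
        // Permit is released here when `response` is dropped.
    });
}
```

Tying the guard to the response, rather than releasing it as soon as the request is sent, is what bounds the number of in-flight streams; the `test_rate_limiting` test in the diff below exercises exactly this behavior.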
@@ -163,7 +163,7 @@ impl LanguageModelRequestMessage {
     }
 }
 
-#[derive(Debug, Default, Serialize)]
+#[derive(Debug, Default, Serialize, Deserialize)]
 pub struct LanguageModelRequest {
     pub model: LanguageModel,
     pub messages: Vec<LanguageModelRequestMessage>,
@@ -1409,7 +1409,7 @@ impl Context {
         }
 
         let request = self.to_completion_request(cx);
-        let stream = CompletionProvider::global(cx).complete(request);
+        let response = CompletionProvider::global(cx).complete(request, cx);
         let assistant_message = self
             .insert_message_after(last_message_id, Role::Assistant, MessageStatus::Pending, cx)
             .unwrap();
@@ -1422,11 +1422,12 @@ impl Context {
 
         let task = cx.spawn({
             |this, mut cx| async move {
+                let response = response.await;
                 let assistant_message_id = assistant_message.id;
                 let mut response_latency = None;
                 let stream_completion = async {
                     let request_start = Instant::now();
-                    let mut messages = stream.await?;
+                    let mut messages = response.inner.await?;
 
                     while let Some(message) = messages.next().await {
                         if response_latency.is_none() {
@@ -1718,10 +1719,11 @@ impl Context {
             temperature: 1.0,
         };
 
-        let stream = CompletionProvider::global(cx).complete(request);
+        let response = CompletionProvider::global(cx).complete(request, cx);
         self.pending_summary = cx.spawn(|this, mut cx| {
             async move {
-                let mut messages = stream.await?;
+                let response = response.await;
+                let mut messages = response.inner.await?;
 
                 while let Some(message) = messages.next().await {
                     let text = message?;
@@ -3642,7 +3644,7 @@ mod tests {
     #[gpui::test]
     fn test_inserting_and_removing_messages(cx: &mut AppContext) {
         let settings_store = SettingsStore::test(cx);
-        cx.set_global(CompletionProvider::Fake(FakeCompletionProvider::default()));
+        FakeCompletionProvider::setup_test(cx);
         cx.set_global(settings_store);
         init(cx);
         let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
@@ -3774,7 +3776,7 @@ mod tests {
     fn test_message_splitting(cx: &mut AppContext) {
         let settings_store = SettingsStore::test(cx);
         cx.set_global(settings_store);
-        cx.set_global(CompletionProvider::Fake(FakeCompletionProvider::default()));
+        FakeCompletionProvider::setup_test(cx);
         init(cx);
         let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
 
@@ -3867,7 +3869,7 @@ mod tests {
     #[gpui::test]
     fn test_messages_for_offsets(cx: &mut AppContext) {
         let settings_store = SettingsStore::test(cx);
-        cx.set_global(CompletionProvider::Fake(FakeCompletionProvider::default()));
+        FakeCompletionProvider::setup_test(cx);
         cx.set_global(settings_store);
         init(cx);
         let registry = Arc::new(LanguageRegistry::test(cx.background_executor().clone()));
@@ -3952,7 +3954,8 @@ mod tests {
     async fn test_slash_commands(cx: &mut TestAppContext) {
         let settings_store = cx.update(SettingsStore::test);
         cx.set_global(settings_store);
-        cx.set_global(CompletionProvider::Fake(FakeCompletionProvider::default()));
+        cx.update(|cx| FakeCompletionProvider::setup_test(cx));
+
         cx.update(Project::init_settings);
         cx.update(init);
         let fs = FakeFs::new(cx.background_executor.clone());
@@ -4147,7 +4150,7 @@ mod tests {
     async fn test_serialization(cx: &mut TestAppContext) {
        let settings_store = cx.update(SettingsStore::test);
        cx.set_global(settings_store);
-       cx.set_global(CompletionProvider::Fake(FakeCompletionProvider::default()));
+       cx.update(FakeCompletionProvider::setup_test);
        cx.update(init);
        let registry = Arc::new(LanguageRegistry::test(cx.executor()));
        let context =
@@ -1,5 +1,6 @@
 use std::fmt;
 
+use crate::{preprocess_anthropic_request, LanguageModel, LanguageModelRequest};
 pub use anthropic::Model as AnthropicModel;
 use gpui::Pixels;
 pub use ollama::Model as OllamaModel;
@@ -15,8 +16,6 @@ use serde::{
 use settings::{Settings, SettingsSources};
 use strum::{EnumIter, IntoEnumIterator};
 
-use crate::{preprocess_anthropic_request, LanguageModel, LanguageModelRequest};
-
 #[derive(Clone, Debug, Default, PartialEq, EnumIter)]
 pub enum CloudModel {
     Gpt3Point5Turbo,
@@ -11,6 +11,8 @@ pub use cloud::*;
 pub use fake::*;
 pub use ollama::*;
 pub use open_ai::*;
+use parking_lot::RwLock;
+use smol::lock::{Semaphore, SemaphoreGuardArc};
 
 use crate::{
     assistant_settings::{AssistantProvider, AssistantSettings},
@@ -21,8 +23,8 @@ use client::Client;
 use futures::{future::BoxFuture, stream::BoxStream};
 use gpui::{AnyView, AppContext, BorrowAppContext, Task, WindowContext};
 use settings::{Settings, SettingsStore};
-use std::sync::Arc;
 use std::time::Duration;
+use std::{any::Any, sync::Arc};
 
 /// Choose which model to use for openai provider.
 /// If the model is not available, try to use the first available model, or fallback to the original model.
@@ -39,176 +41,117 @@ fn choose_openai_model(
 }
 
 pub fn init(client: Arc<Client>, cx: &mut AppContext) {
-    let mut settings_version = 0;
-    let provider = match &AssistantSettings::get_global(cx).provider {
-        AssistantProvider::ZedDotDev { model } => CompletionProvider::Cloud(
-            CloudCompletionProvider::new(model.clone(), client.clone(), settings_version, cx),
-        ),
-        AssistantProvider::OpenAi {
-            model,
-            api_url,
-            low_speed_timeout_in_seconds,
-            available_models,
-        } => CompletionProvider::OpenAi(OpenAiCompletionProvider::new(
-            choose_openai_model(model, available_models),
-            api_url.clone(),
-            client.http_client(),
-            low_speed_timeout_in_seconds.map(Duration::from_secs),
-            settings_version,
-        )),
-        AssistantProvider::Anthropic {
-            model,
-            api_url,
-            low_speed_timeout_in_seconds,
-        } => CompletionProvider::Anthropic(AnthropicCompletionProvider::new(
-            model.clone(),
-            api_url.clone(),
-            client.http_client(),
-            low_speed_timeout_in_seconds.map(Duration::from_secs),
-            settings_version,
-        )),
-        AssistantProvider::Ollama {
-            model,
-            api_url,
-            low_speed_timeout_in_seconds,
-        } => CompletionProvider::Ollama(OllamaCompletionProvider::new(
-            model.clone(),
-            api_url.clone(),
-            client.http_client(),
-            low_speed_timeout_in_seconds.map(Duration::from_secs),
-            settings_version,
-            cx,
-        )),
-    };
-    cx.set_global(provider);
+    let provider = create_provider_from_settings(client.clone(), 0, cx);
+    cx.set_global(CompletionProvider::new(provider, Some(client)));
 
+    let mut settings_version = 0;
     cx.observe_global::<SettingsStore>(move |cx| {
         settings_version += 1;
         cx.update_global::<CompletionProvider, _>(|provider, cx| {
-            match (&mut *provider, &AssistantSettings::get_global(cx).provider) {
-                (
-                    CompletionProvider::OpenAi(provider),
-                    AssistantProvider::OpenAi {
-                        model,
-                        api_url,
-                        low_speed_timeout_in_seconds,
-                        available_models,
-                    },
-                ) => {
-                    provider.update(
-                        choose_openai_model(model, available_models),
-                        api_url.clone(),
-                        low_speed_timeout_in_seconds.map(Duration::from_secs),
-                        settings_version,
-                    );
-                }
-                (
-                    CompletionProvider::Anthropic(provider),
-                    AssistantProvider::Anthropic {
-                        model,
-                        api_url,
-                        low_speed_timeout_in_seconds,
-                    },
-                ) => {
-                    provider.update(
-                        model.clone(),
-                        api_url.clone(),
-                        low_speed_timeout_in_seconds.map(Duration::from_secs),
-                        settings_version,
-                    );
-                }
-
-                (
-                    CompletionProvider::Ollama(provider),
-                    AssistantProvider::Ollama {
-                        model,
-                        api_url,
-                        low_speed_timeout_in_seconds,
-                    },
-                ) => {
-                    provider.update(
-                        model.clone(),
-                        api_url.clone(),
-                        low_speed_timeout_in_seconds.map(Duration::from_secs),
-                        settings_version,
-                        cx,
-                    );
-                }
-
-                (CompletionProvider::Cloud(provider), AssistantProvider::ZedDotDev { model }) => {
-                    provider.update(model.clone(), settings_version);
-                }
-                (_, AssistantProvider::ZedDotDev { model }) => {
-                    *provider = CompletionProvider::Cloud(CloudCompletionProvider::new(
-                        model.clone(),
-                        client.clone(),
-                        settings_version,
-                        cx,
-                    ));
-                }
-                (
-                    _,
-                    AssistantProvider::OpenAi {
-                        model,
-                        api_url,
-                        low_speed_timeout_in_seconds,
-                        available_models,
-                    },
-                ) => {
-                    *provider = CompletionProvider::OpenAi(OpenAiCompletionProvider::new(
-                        choose_openai_model(model, available_models),
-                        api_url.clone(),
-                        client.http_client(),
-                        low_speed_timeout_in_seconds.map(Duration::from_secs),
-                        settings_version,
-                    ));
-                }
-                (
-                    _,
-                    AssistantProvider::Anthropic {
-                        model,
-                        api_url,
-                        low_speed_timeout_in_seconds,
-                    },
-                ) => {
-                    *provider = CompletionProvider::Anthropic(AnthropicCompletionProvider::new(
-                        model.clone(),
-                        api_url.clone(),
-                        client.http_client(),
-                        low_speed_timeout_in_seconds.map(Duration::from_secs),
-                        settings_version,
-                    ));
-                }
-                (
-                    _,
-                    AssistantProvider::Ollama {
-                        model,
-                        api_url,
-                        low_speed_timeout_in_seconds,
-                    },
-                ) => {
-                    *provider = CompletionProvider::Ollama(OllamaCompletionProvider::new(
-                        model.clone(),
-                        api_url.clone(),
-                        client.http_client(),
-                        low_speed_timeout_in_seconds.map(Duration::from_secs),
-                        settings_version,
-                        cx,
-                    ));
-                }
-            }
+            provider.update_settings(settings_version, cx);
         })
     })
     .detach();
 }
 
-pub enum CompletionProvider {
-    OpenAi(OpenAiCompletionProvider),
-    Anthropic(AnthropicCompletionProvider),
-    Cloud(CloudCompletionProvider),
-    #[cfg(test)]
-    Fake(FakeCompletionProvider),
-    Ollama(OllamaCompletionProvider),
+pub struct CompletionResponse {
+    pub inner: BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>,
+    _lock: SemaphoreGuardArc,
+}
+
+pub trait LanguageModelCompletionProvider: Send + Sync {
+    fn available_models(&self, cx: &AppContext) -> Vec<LanguageModel>;
+    fn settings_version(&self) -> usize;
+    fn is_authenticated(&self) -> bool;
+    fn authenticate(&self, cx: &AppContext) -> Task<Result<()>>;
+    fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView;
+    fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>>;
+    fn model(&self) -> LanguageModel;
+    fn count_tokens(
+        &self,
+        request: LanguageModelRequest,
+        cx: &AppContext,
+    ) -> BoxFuture<'static, Result<usize>>;
+    fn complete(
+        &self,
+        request: LanguageModelRequest,
+    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
+
+    fn as_any_mut(&mut self) -> &mut dyn Any;
+}
+
+const MAX_CONCURRENT_COMPLETION_REQUESTS: usize = 4;
+
+pub struct CompletionProvider {
+    provider: Arc<RwLock<dyn LanguageModelCompletionProvider>>,
+    client: Option<Arc<Client>>,
+    request_limiter: Arc<Semaphore>,
+}
+
+impl CompletionProvider {
+    pub fn new(
+        provider: Arc<RwLock<dyn LanguageModelCompletionProvider>>,
+        client: Option<Arc<Client>>,
+    ) -> Self {
+        Self {
+            provider,
+            client,
+            request_limiter: Arc::new(Semaphore::new(MAX_CONCURRENT_COMPLETION_REQUESTS)),
+        }
+    }
+
+    pub fn available_models(&self, cx: &AppContext) -> Vec<LanguageModel> {
+        self.provider.read().available_models(cx)
+    }
+
+    pub fn settings_version(&self) -> usize {
+        self.provider.read().settings_version()
+    }
+
+    pub fn is_authenticated(&self) -> bool {
+        self.provider.read().is_authenticated()
+    }
+
+    pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
+        self.provider.read().authenticate(cx)
+    }
+
+    pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
+        self.provider.read().authentication_prompt(cx)
+    }
+
+    pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
+        self.provider.read().reset_credentials(cx)
+    }
+
+    pub fn model(&self) -> LanguageModel {
+        self.provider.read().model()
+    }
+
+    pub fn count_tokens(
+        &self,
+        request: LanguageModelRequest,
+        cx: &AppContext,
+    ) -> BoxFuture<'static, Result<usize>> {
+        self.provider.read().count_tokens(request, cx)
+    }
+
+    pub fn complete(
+        &self,
+        request: LanguageModelRequest,
+        cx: &AppContext,
+    ) -> Task<CompletionResponse> {
+        let rate_limiter = self.request_limiter.clone();
+        let provider = self.provider.clone();
+        cx.background_executor().spawn(async move {
+            let lock = rate_limiter.acquire_arc().await;
+            let response = provider.read().complete(request);
+            CompletionResponse {
+                inner: response,
+                _lock: lock,
+            }
+        })
+    }
 }
 
 impl gpui::Global for CompletionProvider {}
@@ -218,121 +161,213 @@ impl CompletionProvider {
         cx.global::<Self>()
     }
 
-    pub fn available_models(&self, cx: &AppContext) -> Vec<LanguageModel> {
-        match self {
-            CompletionProvider::OpenAi(provider) => provider
-                .available_models(cx)
-                .map(LanguageModel::OpenAi)
-                .collect(),
-            CompletionProvider::Anthropic(provider) => provider
-                .available_models()
-                .map(LanguageModel::Anthropic)
-                .collect(),
-            CompletionProvider::Cloud(provider) => provider
-                .available_models()
-                .map(LanguageModel::Cloud)
-                .collect(),
-            CompletionProvider::Ollama(provider) => provider
-                .available_models()
-                .map(|model| LanguageModel::Ollama(model.clone()))
-                .collect(),
-            #[cfg(test)]
-            CompletionProvider::Fake(_) => unimplemented!(),
-        }
-    }
-
-    pub fn settings_version(&self) -> usize {
-        match self {
-            CompletionProvider::OpenAi(provider) => provider.settings_version(),
-            CompletionProvider::Anthropic(provider) => provider.settings_version(),
-            CompletionProvider::Cloud(provider) => provider.settings_version(),
-            CompletionProvider::Ollama(provider) => provider.settings_version(),
-            #[cfg(test)]
-            CompletionProvider::Fake(_) => unimplemented!(),
-        }
-    }
-
-    pub fn is_authenticated(&self) -> bool {
-        match self {
-            CompletionProvider::OpenAi(provider) => provider.is_authenticated(),
-            CompletionProvider::Anthropic(provider) => provider.is_authenticated(),
-            CompletionProvider::Cloud(provider) => provider.is_authenticated(),
-            CompletionProvider::Ollama(provider) => provider.is_authenticated(),
-            #[cfg(test)]
-            CompletionProvider::Fake(_) => true,
-        }
-    }
-
-    pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
-        match self {
-            CompletionProvider::OpenAi(provider) => provider.authenticate(cx),
-            CompletionProvider::Anthropic(provider) => provider.authenticate(cx),
-            CompletionProvider::Cloud(provider) => provider.authenticate(cx),
-            CompletionProvider::Ollama(provider) => provider.authenticate(cx),
-            #[cfg(test)]
-            CompletionProvider::Fake(_) => Task::ready(Ok(())),
-        }
-    }
-
-    pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
-        match self {
-            CompletionProvider::OpenAi(provider) => provider.authentication_prompt(cx),
-            CompletionProvider::Anthropic(provider) => provider.authentication_prompt(cx),
-            CompletionProvider::Cloud(provider) => provider.authentication_prompt(cx),
-            CompletionProvider::Ollama(provider) => provider.authentication_prompt(cx),
-            #[cfg(test)]
-            CompletionProvider::Fake(_) => unimplemented!(),
-        }
-    }
-
-    pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
-        match self {
-            CompletionProvider::OpenAi(provider) => provider.reset_credentials(cx),
-            CompletionProvider::Anthropic(provider) => provider.reset_credentials(cx),
-            CompletionProvider::Cloud(_) => Task::ready(Ok(())),
-            CompletionProvider::Ollama(provider) => provider.reset_credentials(cx),
-            #[cfg(test)]
-            CompletionProvider::Fake(_) => Task::ready(Ok(())),
-        }
-    }
-
-    pub fn model(&self) -> LanguageModel {
-        match self {
-            CompletionProvider::OpenAi(provider) => LanguageModel::OpenAi(provider.model()),
-            CompletionProvider::Anthropic(provider) => LanguageModel::Anthropic(provider.model()),
-            CompletionProvider::Cloud(provider) => LanguageModel::Cloud(provider.model()),
-            CompletionProvider::Ollama(provider) => LanguageModel::Ollama(provider.model()),
-            #[cfg(test)]
-            CompletionProvider::Fake(_) => LanguageModel::default(),
-        }
-    }
-
-    pub fn count_tokens(
-        &self,
-        request: LanguageModelRequest,
-        cx: &AppContext,
-    ) -> BoxFuture<'static, Result<usize>> {
-        match self {
-            CompletionProvider::OpenAi(provider) => provider.count_tokens(request, cx),
-            CompletionProvider::Anthropic(provider) => provider.count_tokens(request, cx),
-            CompletionProvider::Cloud(provider) => provider.count_tokens(request, cx),
-            CompletionProvider::Ollama(provider) => provider.count_tokens(request, cx),
-            #[cfg(test)]
-            CompletionProvider::Fake(_) => futures::FutureExt::boxed(futures::future::ready(Ok(0))),
-        }
-    }
-
-    pub fn complete(
-        &self,
-        request: LanguageModelRequest,
-    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
-        match self {
-            CompletionProvider::OpenAi(provider) => provider.complete(request),
-            CompletionProvider::Anthropic(provider) => provider.complete(request),
-            CompletionProvider::Cloud(provider) => provider.complete(request),
-            CompletionProvider::Ollama(provider) => provider.complete(request),
-            #[cfg(test)]
-            CompletionProvider::Fake(provider) => provider.complete(),
-        }
+    pub fn update_current_as<R, T: LanguageModelCompletionProvider + 'static>(
+        &mut self,
+        update: impl FnOnce(&mut T) -> R,
+    ) -> Option<R> {
+        let mut provider = self.provider.write();
+        if let Some(provider) = provider.as_any_mut().downcast_mut::<T>() {
+            Some(update(provider))
+        } else {
+            None
+        }
+    }
+
+    pub fn update_settings(&mut self, version: usize, cx: &mut AppContext) {
+        let updated = match &AssistantSettings::get_global(cx).provider {
+            AssistantProvider::ZedDotDev { model } => self
+                .update_current_as::<_, CloudCompletionProvider>(|provider| {
+                    provider.update(model.clone(), version);
+                }),
+            AssistantProvider::OpenAi {
+                model,
+                api_url,
+                low_speed_timeout_in_seconds,
+                available_models,
+            } => self.update_current_as::<_, OpenAiCompletionProvider>(|provider| {
+                provider.update(
+                    choose_openai_model(&model, &available_models),
+                    api_url.clone(),
+                    low_speed_timeout_in_seconds.map(Duration::from_secs),
+                    version,
+                );
+            }),
+            AssistantProvider::Anthropic {
+                model,
+                api_url,
+                low_speed_timeout_in_seconds,
+            } => self.update_current_as::<_, AnthropicCompletionProvider>(|provider| {
+                provider.update(
+                    model.clone(),
+                    api_url.clone(),
+                    low_speed_timeout_in_seconds.map(Duration::from_secs),
+                    version,
+                );
+            }),
+            AssistantProvider::Ollama {
+                model,
+                api_url,
+                low_speed_timeout_in_seconds,
+            } => self.update_current_as::<_, OllamaCompletionProvider>(|provider| {
+                provider.update(
+                    model.clone(),
+                    api_url.clone(),
+                    low_speed_timeout_in_seconds.map(Duration::from_secs),
+                    version,
+                    cx,
+                );
+            }),
+        };
+
+        // Previously configured provider was changed to another one
+        if updated.is_none() {
+            if let Some(client) = self.client.clone() {
+                self.provider = create_provider_from_settings(client, version, cx);
+            } else {
+                log::warn!("completion provider cannot be created because client is not set");
+            }
+        }
+    }
+}
+
+fn create_provider_from_settings(
+    client: Arc<Client>,
+    settings_version: usize,
+    cx: &mut AppContext,
+) -> Arc<RwLock<dyn LanguageModelCompletionProvider>> {
+    match &AssistantSettings::get_global(cx).provider {
+        AssistantProvider::ZedDotDev { model } => Arc::new(RwLock::new(
+            CloudCompletionProvider::new(model.clone(), client.clone(), settings_version, cx),
+        )),
+        AssistantProvider::OpenAi {
+            model,
+            api_url,
+            low_speed_timeout_in_seconds,
+            available_models,
+        } => Arc::new(RwLock::new(OpenAiCompletionProvider::new(
+            choose_openai_model(&model, &available_models),
+            api_url.clone(),
+            client.http_client(),
+            low_speed_timeout_in_seconds.map(Duration::from_secs),
+            settings_version,
+        ))),
+        AssistantProvider::Anthropic {
+            model,
+            api_url,
+            low_speed_timeout_in_seconds,
+        } => Arc::new(RwLock::new(AnthropicCompletionProvider::new(
+            model.clone(),
+            api_url.clone(),
+            client.http_client(),
+            low_speed_timeout_in_seconds.map(Duration::from_secs),
+            settings_version,
+        ))),
+        AssistantProvider::Ollama {
+            model,
+            api_url,
+            low_speed_timeout_in_seconds,
+        } => Arc::new(RwLock::new(OllamaCompletionProvider::new(
+            model.clone(),
+            api_url.clone(),
+            client.http_client(),
+            low_speed_timeout_in_seconds.map(Duration::from_secs),
+            settings_version,
+            cx,
+        ))),
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use std::sync::Arc;
+
+    use gpui::AppContext;
+    use parking_lot::RwLock;
+    use settings::SettingsStore;
+    use smol::stream::StreamExt;
+
+    use crate::{
+        completion_provider::MAX_CONCURRENT_COMPLETION_REQUESTS, CompletionProvider,
+        FakeCompletionProvider, LanguageModelRequest,
+    };
+
+    #[gpui::test]
+    fn test_rate_limiting(cx: &mut AppContext) {
+        SettingsStore::test(cx);
+        let fake_provider = FakeCompletionProvider::setup_test(cx);
+
+        let provider = CompletionProvider::new(Arc::new(RwLock::new(fake_provider.clone())), None);
+
+        // Enqueue some requests
+        for i in 0..MAX_CONCURRENT_COMPLETION_REQUESTS * 2 {
+            let response = provider.complete(
+                LanguageModelRequest {
+                    temperature: i as f32 / 10.0,
+                    ..Default::default()
+                },
+                cx,
+            );
+            cx.background_executor()
+                .spawn(async move {
+                    let response = response.await;
+                    let mut stream = response.inner.await.unwrap();
+                    while let Some(message) = stream.next().await {
+                        message.unwrap();
+                    }
+                })
+                .detach();
+        }
+        cx.background_executor().run_until_parked();
+
+        assert_eq!(
+            fake_provider.completion_count(),
+            MAX_CONCURRENT_COMPLETION_REQUESTS
+        );
+
+        // Get the first completion request that is in flight and mark it as completed.
+        let completion = fake_provider
+            .running_completions()
+            .into_iter()
+            .next()
+            .unwrap();
+        fake_provider.finish_completion(&completion);
+
+        // Ensure that the number of in-flight completion requests is reduced.
+        assert_eq!(
+            fake_provider.completion_count(),
+            MAX_CONCURRENT_COMPLETION_REQUESTS - 1
+        );
+
+        cx.background_executor().run_until_parked();
+
+        // Ensure that another completion request was allowed to acquire the lock.
+        assert_eq!(
+            fake_provider.completion_count(),
+            MAX_CONCURRENT_COMPLETION_REQUESTS
+        );
+
+        // Mark all completion requests as finished that are in flight.
+        for request in fake_provider.running_completions() {
+            fake_provider.finish_completion(&request);
+        }
+
+        assert_eq!(fake_provider.completion_count(), 0);
+
+        // Wait until the background tasks acquire the lock again.
+        cx.background_executor().run_until_parked();
+
+        assert_eq!(
+            fake_provider.completion_count(),
+            MAX_CONCURRENT_COMPLETION_REQUESTS - 1
+        );
+
+        // Finish all remaining completion requests.
+        for request in fake_provider.running_completions() {
+            fake_provider.finish_completion(&request);
+        }
+
+        cx.background_executor().run_until_parked();
+
+        assert_eq!(fake_provider.completion_count(), 0);
     }
 }
@@ -2,7 +2,7 @@ use crate::{
     assistant_settings::AnthropicModel, CompletionProvider, LanguageModel, LanguageModelRequest,
     Role,
 };
-use crate::{count_open_ai_tokens, LanguageModelRequestMessage};
+use crate::{count_open_ai_tokens, LanguageModelCompletionProvider, LanguageModelRequestMessage};
 use anthropic::{stream_completion, Request, RequestMessage};
 use anyhow::{anyhow, Result};
 use editor::{Editor, EditorElement, EditorStyle};
@@ -26,50 +26,22 @@ pub struct AnthropicCompletionProvider {
     settings_version: usize,
 }
 
-impl AnthropicCompletionProvider {
-    pub fn new(
-        model: AnthropicModel,
-        api_url: String,
-        http_client: Arc<dyn HttpClient>,
-        low_speed_timeout: Option<Duration>,
-        settings_version: usize,
-    ) -> Self {
-        Self {
-            api_key: None,
-            api_url,
-            model,
-            http_client,
-            low_speed_timeout,
-            settings_version,
-        }
-    }
-
-    pub fn update(
-        &mut self,
-        model: AnthropicModel,
-        api_url: String,
-        low_speed_timeout: Option<Duration>,
-        settings_version: usize,
-    ) {
-        self.model = model;
-        self.api_url = api_url;
-        self.low_speed_timeout = low_speed_timeout;
-        self.settings_version = settings_version;
-    }
-
-    pub fn available_models(&self) -> impl Iterator<Item = AnthropicModel> {
+impl LanguageModelCompletionProvider for AnthropicCompletionProvider {
+    fn available_models(&self, _cx: &AppContext) -> Vec<LanguageModel> {
         AnthropicModel::iter()
+            .map(LanguageModel::Anthropic)
+            .collect()
     }
 
-    pub fn settings_version(&self) -> usize {
+    fn settings_version(&self) -> usize {
         self.settings_version
     }
 
-    pub fn is_authenticated(&self) -> bool {
+    fn is_authenticated(&self) -> bool {
         self.api_key.is_some()
     }
 
-    pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
+    fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
         if self.is_authenticated() {
             Task::ready(Ok(()))
         } else {
@@ -85,36 +57,36 @@ impl AnthropicCompletionProvider {
                     String::from_utf8(api_key)?
                 };
                 cx.update_global::<CompletionProvider, _>(|provider, _cx| {
-                    if let CompletionProvider::Anthropic(provider) = provider {
+                    provider.update_current_as::<_, AnthropicCompletionProvider>(|provider| {
                         provider.api_key = Some(api_key);
-                    }
+                    });
                 })
             })
         }
     }
 
-    pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
+    fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
         let delete_credentials = cx.delete_credentials(&self.api_url);
         cx.spawn(|mut cx| async move {
             delete_credentials.await.log_err();
             cx.update_global::<CompletionProvider, _>(|provider, _cx| {
-                if let CompletionProvider::Anthropic(provider) = provider {
+                provider.update_current_as::<_, AnthropicCompletionProvider>(|provider| {
                     provider.api_key = None;
-                }
+                });
             })
         })
     }
 
-    pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
+    fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
         cx.new_view(|cx| AuthenticationPrompt::new(self.api_url.clone(), cx))
             .into()
     }
 
-    pub fn model(&self) -> AnthropicModel {
-        self.model.clone()
+    fn model(&self) -> LanguageModel {
+        LanguageModel::Anthropic(self.model.clone())
     }
 
-    pub fn count_tokens(
+    fn count_tokens(
         &self,
         request: LanguageModelRequest,
         cx: &AppContext,
@@ -122,7 +94,7 @@ impl AnthropicCompletionProvider {
         count_open_ai_tokens(request, cx.background_executor())
     }
 
-    pub fn complete(
+    fn complete(
         &self,
         request: LanguageModelRequest,
     ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
@@ -167,12 +139,48 @@ impl AnthropicCompletionProvider {
         .boxed()
     }
 
+    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
+        self
+    }
+}
+
+impl AnthropicCompletionProvider {
+    pub fn new(
+        model: AnthropicModel,
+        api_url: String,
+        http_client: Arc<dyn HttpClient>,
+        low_speed_timeout: Option<Duration>,
+        settings_version: usize,
+    ) -> Self {
+        Self {
+            api_key: None,
+            api_url,
+            model,
+            http_client,
+            low_speed_timeout,
+            settings_version,
+        }
+    }
+
+    pub fn update(
+        &mut self,
+        model: AnthropicModel,
+        api_url: String,
+        low_speed_timeout: Option<Duration>,
+        settings_version: usize,
+    ) {
+        self.model = model;
+        self.api_url = api_url;
+        self.low_speed_timeout = low_speed_timeout;
+        self.settings_version = settings_version;
+    }
+
     fn to_anthropic_request(&self, mut request: LanguageModelRequest) -> Request {
         preprocess_anthropic_request(&mut request);
 
         let model = match request.model {
             LanguageModel::Anthropic(model) => model,
-            _ => self.model(),
+            _ => self.model.clone(),
         };
 
         let mut system_message = String::new();
@@ -278,9 +286,9 @@ impl AuthenticationPrompt {
         cx.spawn(|_, mut cx| async move {
             write_credentials.await?;
             cx.update_global::<CompletionProvider, _>(|provider, _cx| {
-                if let CompletionProvider::Anthropic(provider) = provider {
+                provider.update_current_as::<_, AnthropicCompletionProvider>(|provider| {
                     provider.api_key = Some(api_key);
-                }
+                });
             })
         })
         .detach_and_log_err(cx);
@@ -1,6 +1,6 @@
 use crate::{
     assistant_settings::CloudModel, count_open_ai_tokens, CompletionProvider, LanguageModel,
-    LanguageModelRequest,
+    LanguageModelCompletionProvider, LanguageModelRequest,
 };
 use anyhow::{anyhow, Result};
 use client::{proto, Client};
@@ -30,11 +30,9 @@ impl CloudCompletionProvider {
         let maintain_client_status = cx.spawn(|mut cx| async move {
             while let Some(status) = status_rx.next().await {
                 let _ = cx.update_global::<CompletionProvider, _>(|provider, _cx| {
-                    if let CompletionProvider::Cloud(provider) = provider {
+                    provider.update_current_as::<_, Self>(|provider| {
                         provider.status = status;
-                    } else {
-                        unreachable!()
-                    }
+                    });
                 });
             }
         });
@@ -51,44 +49,53 @@ impl CloudCompletionProvider {
         self.model = model;
         self.settings_version = settings_version;
     }
+}
 
-    pub fn available_models(&self) -> impl Iterator<Item = CloudModel> {
+impl LanguageModelCompletionProvider for CloudCompletionProvider {
+    fn available_models(&self, _cx: &AppContext) -> Vec<LanguageModel> {
         let mut custom_model = if let CloudModel::Custom(custom_model) = self.model.clone() {
             Some(custom_model)
         } else {
             None
         };
-        CloudModel::iter().filter_map(move |model| {
+        CloudModel::iter()
+            .filter_map(move |model| {
                 if let CloudModel::Custom(_) = model {
                     Some(CloudModel::Custom(custom_model.take()?))
                 } else {
                     Some(model)
                 }
             })
+            .map(LanguageModel::Cloud)
+            .collect()
     }
 
-    pub fn settings_version(&self) -> usize {
+    fn settings_version(&self) -> usize {
         self.settings_version
     }
 
-    pub fn model(&self) -> CloudModel {
-        self.model.clone()
-    }
-
-    pub fn is_authenticated(&self) -> bool {
+    fn is_authenticated(&self) -> bool {
         self.status.is_connected()
     }
 
-    pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
+    fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
         let client = self.client.clone();
         cx.spawn(move |cx| async move { client.authenticate_and_connect(true, &cx).await })
     }
 
-    pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
+    fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
         cx.new_view(|_cx| AuthenticationPrompt).into()
     }
 
-    pub fn count_tokens(
+    fn reset_credentials(&self, _cx: &AppContext) -> Task<Result<()>> {
+        Task::ready(Ok(()))
+    }
+
+    fn model(&self) -> LanguageModel {
+        LanguageModel::Cloud(self.model.clone())
+    }
+
+    fn count_tokens(
         &self,
         request: LanguageModelRequest,
         cx: &AppContext,
@@ -128,7 +135,7 @@ impl CloudCompletionProvider {
         }
     }
 
-    pub fn complete(
+    fn complete(
         &self,
         mut request: LanguageModelRequest,
     ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
@@ -161,6 +168,10 @@ impl CloudCompletionProvider {
         })
         .boxed()
     }
+
+    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
+        self
+    }
 }
 
 struct AuthenticationPrompt;
@@ -1,29 +1,107 @@
 use anyhow::Result;
+use collections::HashMap;
 use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
+use gpui::{AnyView, AppContext, Task};
 use std::sync::Arc;
+use ui::WindowContext;
+
+use crate::{LanguageModel, LanguageModelCompletionProvider, LanguageModelRequest};
 
 #[derive(Clone, Default)]
 pub struct FakeCompletionProvider {
-    current_completion_tx: Arc<parking_lot::Mutex<Option<mpsc::UnboundedSender<String>>>>,
+    current_completion_txs: Arc<parking_lot::Mutex<HashMap<String, mpsc::UnboundedSender<String>>>>,
 }
 
 impl FakeCompletionProvider {
-    pub fn complete(&self) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
-        let (tx, rx) = mpsc::unbounded();
-        *self.current_completion_tx.lock() = Some(tx);
-        async move { Ok(rx.map(Ok).boxed()) }.boxed()
+    #[cfg(test)]
+    pub fn setup_test(cx: &mut AppContext) -> Self {
+        use crate::CompletionProvider;
+        use parking_lot::RwLock;
+
+        let this = Self::default();
+        let provider = CompletionProvider::new(Arc::new(RwLock::new(this.clone())), None);
+        cx.set_global(provider);
+        this
     }
 
-    pub fn send_completion(&self, chunk: String) {
-        self.current_completion_tx
+    pub fn running_completions(&self) -> Vec<LanguageModelRequest> {
+        self.current_completion_txs
             .lock()
-            .as_ref()
+            .keys()
+            .map(|k| serde_json::from_str(k).unwrap())
+            .collect()
+    }
+
+    pub fn completion_count(&self) -> usize {
+        self.current_completion_txs.lock().len()
+    }
+
+    pub fn send_completion(&self, request: &LanguageModelRequest, chunk: String) {
+        let json = serde_json::to_string(request).unwrap();
+        self.current_completion_txs
+            .lock()
+            .get(&json)
             .unwrap()
             .unbounded_send(chunk)
             .unwrap();
     }
 
-    pub fn finish_completion(&self) {
-        self.current_completion_tx.lock().take();
+    pub fn finish_completion(&self, request: &LanguageModelRequest) {
+        self.current_completion_txs
+            .lock()
+            .remove(&serde_json::to_string(request).unwrap());
+    }
+}
+
+impl LanguageModelCompletionProvider for FakeCompletionProvider {
+    fn available_models(&self, _cx: &AppContext) -> Vec<LanguageModel> {
+        vec![LanguageModel::default()]
+    }
+
+    fn settings_version(&self) -> usize {
+        0
+    }
+
+    fn is_authenticated(&self) -> bool {
+        true
+    }
+
+    fn authenticate(&self, _cx: &AppContext) -> Task<Result<()>> {
+        Task::ready(Ok(()))
+    }
+
+    fn authentication_prompt(&self, _cx: &mut WindowContext) -> AnyView {
+        unimplemented!()
+    }
+
+    fn reset_credentials(&self, _cx: &AppContext) -> Task<Result<()>> {
+        Task::ready(Ok(()))
+    }
+
+    fn model(&self) -> LanguageModel {
+        LanguageModel::default()
+    }
+
+    fn count_tokens(
+        &self,
+        _request: LanguageModelRequest,
+        _cx: &AppContext,
+    ) -> BoxFuture<'static, Result<usize>> {
+        futures::future::ready(Ok(0)).boxed()
+    }
+
+    fn complete(
+        &self,
+        _request: LanguageModelRequest,
+    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
+        let (tx, rx) = mpsc::unbounded();
+        self.current_completion_txs
+            .lock()
+            .insert(serde_json::to_string(&_request).unwrap(), tx);
+        async move { Ok(rx.map(Ok).boxed()) }.boxed()
+    }
+
+    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
+        self
     }
 }
@@ -1,3 +1,4 @@
+use crate::LanguageModelCompletionProvider;
 use crate::{
     assistant_settings::OllamaModel, CompletionProvider, LanguageModel, LanguageModelRequest, Role,
 };
@@ -26,6 +27,108 @@ pub struct OllamaCompletionProvider {
     available_models: Vec<OllamaModel>,
 }
 
+impl LanguageModelCompletionProvider for OllamaCompletionProvider {
+    fn available_models(&self, _cx: &AppContext) -> Vec<LanguageModel> {
+        self.available_models
+            .iter()
+            .map(|m| LanguageModel::Ollama(m.clone()))
+            .collect()
+    }
+
+    fn settings_version(&self) -> usize {
+        self.settings_version
+    }
+
+    fn is_authenticated(&self) -> bool {
+        !self.available_models.is_empty()
+    }
+
+    fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
+        if self.is_authenticated() {
+            Task::ready(Ok(()))
+        } else {
+            self.fetch_models(cx)
+        }
+    }
+
+    fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
+        let fetch_models = Box::new(move |cx: &mut WindowContext| {
+            cx.update_global::<CompletionProvider, _>(|provider, cx| {
+                provider
+                    .update_current_as::<_, OllamaCompletionProvider>(|provider| {
+                        provider.fetch_models(cx)
+                    })
+                    .unwrap_or_else(|| Task::ready(Ok(())))
+            })
+        });
+
+        cx.new_view(|cx| DownloadOllamaMessage::new(fetch_models, cx))
+            .into()
+    }
+
+    fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
+        self.fetch_models(cx)
+    }
+
+    fn model(&self) -> LanguageModel {
+        LanguageModel::Ollama(self.model.clone())
+    }
+
+    fn count_tokens(
+        &self,
+        request: LanguageModelRequest,
+        _cx: &AppContext,
+    ) -> BoxFuture<'static, Result<usize>> {
+        // There is no endpoint for this _yet_ in Ollama
+        // see: https://github.com/ollama/ollama/issues/1716 and https://github.com/ollama/ollama/issues/3582
+        let token_count = request
+            .messages
+            .iter()
+            .map(|msg| msg.content.chars().count())
+            .sum::<usize>()
+            / 4;
+
+        async move { Ok(token_count) }.boxed()
+    }
+
+    fn complete(
+        &self,
+        request: LanguageModelRequest,
+    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
+        let request = self.to_ollama_request(request);
+
+        let http_client = self.http_client.clone();
+        let api_url = self.api_url.clone();
+        let low_speed_timeout = self.low_speed_timeout;
+        async move {
+            let request =
+                stream_chat_completion(http_client.as_ref(), &api_url, request, low_speed_timeout);
+            let response = request.await?;
+            let stream = response
+                .filter_map(|response| async move {
+                    match response {
+                        Ok(delta) => {
+                            let content = match delta.message {
+                                ChatMessage::User { content } => content,
+                                ChatMessage::Assistant { content } => content,
+                                ChatMessage::System { content } => content,
+                            };
+                            Some(Ok(content))
+                        }
+                        Err(error) => Some(Err(error)),
+                    }
+                })
+                .boxed();
+            Ok(stream)
+        }
+        .boxed()
+    }
+
+    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
+        self
+    }
+}
+
 impl OllamaCompletionProvider {
     pub fn new(
         model: OllamaModel,
@@ -87,36 +190,12 @@ impl OllamaCompletionProvider {
         self.settings_version = settings_version;
     }
 
-    pub fn available_models(&self) -> impl Iterator<Item = &OllamaModel> {
-        self.available_models.iter()
-    }
-
     pub fn select_first_available_model(&mut self) {
         if let Some(model) = self.available_models.first() {
             self.model = model.clone();
         }
     }
 
-    pub fn settings_version(&self) -> usize {
-        self.settings_version
-    }
-
-    pub fn is_authenticated(&self) -> bool {
-        !self.available_models.is_empty()
-    }
-
-    pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
-        if self.is_authenticated() {
-            Task::ready(Ok(()))
-        } else {
-            self.fetch_models(cx)
-        }
-    }
-
-    pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
-        self.fetch_models(cx)
-    }
-
     pub fn fetch_models(&self, cx: &AppContext) -> Task<Result<()>> {
         let http_client = self.http_client.clone();
         let api_url = self.api_url.clone();
@@ -137,90 +216,21 @@ impl OllamaCompletionProvider {
             models.sort_by(|a, b| a.name.cmp(&b.name));
 
             cx.update_global::<CompletionProvider, _>(|provider, _cx| {
-                if let CompletionProvider::Ollama(provider) = provider {
+                provider.update_current_as::<_, OllamaCompletionProvider>(|provider| {
                     provider.available_models = models;
 
                     if !provider.available_models.is_empty() && provider.model.name.is_empty() {
                         provider.select_first_available_model()
                     }
-                }
-            })
-        })
-    }
-
-    pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
-        let fetch_models = Box::new(move |cx: &mut WindowContext| {
-            cx.update_global::<CompletionProvider, _>(|provider, cx| {
-                if let CompletionProvider::Ollama(provider) = provider {
-                    provider.fetch_models(cx)
-                } else {
-                    Task::ready(Ok(()))
-                }
-            })
                 });
-
-        cx.new_view(|cx| DownloadOllamaMessage::new(fetch_models, cx))
-            .into()
-    }
-
-    pub fn model(&self) -> OllamaModel {
-        self.model.clone()
-    }
-
-    pub fn count_tokens(
-        &self,
-        request: LanguageModelRequest,
-        _cx: &AppContext,
-    ) -> BoxFuture<'static, Result<usize>> {
-        // There is no endpoint for this _yet_ in Ollama
-        // see: https://github.com/ollama/ollama/issues/1716 and https://github.com/ollama/ollama/issues/3582
-        let token_count = request
-            .messages
-            .iter()
-            .map(|msg| msg.content.chars().count())
-            .sum::<usize>()
-            / 4;
-
-        async move { Ok(token_count) }.boxed()
-    }
-
-    pub fn complete(
-        &self,
-        request: LanguageModelRequest,
-    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
-        let request = self.to_ollama_request(request);
-
-        let http_client = self.http_client.clone();
-        let api_url = self.api_url.clone();
-        let low_speed_timeout = self.low_speed_timeout;
-        async move {
-            let request =
-                stream_chat_completion(http_client.as_ref(), &api_url, request, low_speed_timeout);
-            let response = request.await?;
-            let stream = response
-                .filter_map(|response| async move {
-                    match response {
-                        Ok(delta) => {
-                            let content = match delta.message {
-                                ChatMessage::User { content } => content,
-                                ChatMessage::Assistant { content } => content,
-                                ChatMessage::System { content } => content,
-                            };
-                            Some(Ok(content))
-                        }
-                        Err(error) => Some(Err(error)),
-                    }
             })
-                .boxed();
-            Ok(stream)
-        }
-        .boxed()
+        })
     }
 
     fn to_ollama_request(&self, request: LanguageModelRequest) -> ChatRequest {
         let model = match request.model {
             LanguageModel::Ollama(model) => model,
-            _ => self.model(),
+            _ => self.model.clone(),
         };
 
         ChatRequest {
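Note: the deleted Ollama `count_tokens` above relies on a characters-over-four approximation, since Ollama exposes no tokenizer endpoint yet (ollama/ollama#1716, #3582). For clarity, the heuristic in isolation, as a standalone sketch rather than code from this commit:

```rust
// Rough token estimate while Ollama has no tokenizer endpoint:
// assume roughly four characters per token.
fn estimate_tokens(messages: &[String]) -> usize {
    messages.iter().map(|m| m.chars().count()).sum::<usize>() / 4
}

fn main() {
    let msgs = vec!["Hello, world!".to_string()]; // 13 chars -> ~3 tokens
    assert_eq!(estimate_tokens(&msgs), 3);
}
```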
@@ -1,5 +1,6 @@
 use crate::assistant_settings::CloudModel;
 use crate::assistant_settings::{AssistantProvider, AssistantSettings};
+use crate::LanguageModelCompletionProvider;
 use crate::{
     assistant_settings::OpenAiModel, CompletionProvider, LanguageModel, LanguageModelRequest, Role,
 };
@@ -57,37 +58,75 @@ impl OpenAiCompletionProvider {
         self.settings_version = settings_version;
     }
 
-    pub fn available_models(&self, cx: &AppContext) -> impl Iterator<Item = OpenAiModel> {
+    fn to_open_ai_request(&self, request: LanguageModelRequest) -> Request {
+        let model = match request.model {
+            LanguageModel::OpenAi(model) => model,
+            _ => self.model.clone(),
+        };
+
+        Request {
+            model,
+            messages: request
+                .messages
+                .into_iter()
+                .map(|msg| match msg.role {
+                    Role::User => RequestMessage::User {
+                        content: msg.content,
+                    },
+                    Role::Assistant => RequestMessage::Assistant {
+                        content: Some(msg.content),
+                        tool_calls: Vec::new(),
+                    },
+                    Role::System => RequestMessage::System {
+                        content: msg.content,
+                    },
+                })
+                .collect(),
+            stream: true,
+            stop: request.stop,
+            temperature: request.temperature,
+            tools: Vec::new(),
+            tool_choice: None,
+        }
+    }
+}
+
+impl LanguageModelCompletionProvider for OpenAiCompletionProvider {
+    fn available_models(&self, cx: &AppContext) -> Vec<LanguageModel> {
         if let AssistantProvider::OpenAi {
             available_models, ..
         } = &AssistantSettings::get_global(cx).provider
         {
             if !available_models.is_empty() {
-                // available_models is set, just return it
-                return available_models.clone().into_iter();
+                return available_models
+                    .iter()
+                    .cloned()
+                    .map(LanguageModel::OpenAi)
+                    .collect();
             }
         }
         let available_models = if matches!(self.model, OpenAiModel::Custom { .. }) {
-            // available_models is not set but the default model is set to custom, only show custom
             vec![self.model.clone()]
         } else {
-            // default case, use all models except custom
             OpenAiModel::iter()
                 .filter(|model| !matches!(model, OpenAiModel::Custom { .. }))
                 .collect()
         };
-        available_models.into_iter()
+        available_models
+            .into_iter()
+            .map(LanguageModel::OpenAi)
+            .collect()
     }
 
-    pub fn settings_version(&self) -> usize {
+    fn settings_version(&self) -> usize {
         self.settings_version
     }
 
-    pub fn is_authenticated(&self) -> bool {
+    fn is_authenticated(&self) -> bool {
         self.api_key.is_some()
     }
 
-    pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
+    fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
         if self.is_authenticated() {
             Task::ready(Ok(()))
         } else {
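Note: taken together, the methods implemented in this file give a good picture of the new trait. The following is an inferred sketch based only on the signatures visible in this diff; the actual `LanguageModelCompletionProvider` definition lives elsewhere in the commit and may differ in details.

```rust
// Inferred trait shape; imports are assumptions about the crates in use.
use anyhow::Result;
use futures::{future::BoxFuture, stream::BoxStream};
use gpui::{AnyView, AppContext, Task, WindowContext};

use crate::{LanguageModel, LanguageModelRequest};

pub trait LanguageModelCompletionProvider: 'static {
    fn available_models(&self, cx: &AppContext) -> Vec<LanguageModel>;
    fn settings_version(&self) -> usize;
    fn is_authenticated(&self) -> bool;
    fn authenticate(&self, cx: &AppContext) -> Task<Result<()>>;
    fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>>;
    fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView;
    fn model(&self) -> LanguageModel;
    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        cx: &AppContext,
    ) -> BoxFuture<'static, Result<usize>>;
    fn complete(
        &self,
        request: LanguageModelRequest,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
    fn as_any_mut(&mut self) -> &mut dyn std::any::Any;
}
```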
@@ -103,36 +142,36 @@ impl OpenAiCompletionProvider {
                     String::from_utf8(api_key)?
                 };
                 cx.update_global::<CompletionProvider, _>(|provider, _cx| {
-                    if let CompletionProvider::OpenAi(provider) = provider {
+                    provider.update_current_as::<_, Self>(|provider| {
                         provider.api_key = Some(api_key);
-                    }
+                    });
                 })
             })
         }
     }
 
-    pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
+    fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
         let delete_credentials = cx.delete_credentials(&self.api_url);
         cx.spawn(|mut cx| async move {
             delete_credentials.await.log_err();
             cx.update_global::<CompletionProvider, _>(|provider, _cx| {
-                if let CompletionProvider::OpenAi(provider) = provider {
+                provider.update_current_as::<_, Self>(|provider| {
                     provider.api_key = None;
-                }
+                });
             })
         })
     }
 
-    pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
+    fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
         cx.new_view(|cx| AuthenticationPrompt::new(self.api_url.clone(), cx))
             .into()
     }
 
-    pub fn model(&self) -> OpenAiModel {
-        self.model.clone()
+    fn model(&self) -> LanguageModel {
+        LanguageModel::OpenAi(self.model.clone())
     }
 
-    pub fn count_tokens(
+    fn count_tokens(
         &self,
         request: LanguageModelRequest,
         cx: &AppContext,
@@ -140,7 +179,7 @@ impl OpenAiCompletionProvider {
         count_open_ai_tokens(request, cx.background_executor())
     }
 
-    pub fn complete(
+    fn complete(
         &self,
         request: LanguageModelRequest,
     ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
@@ -173,36 +212,8 @@ impl OpenAiCompletionProvider {
         .boxed()
     }
 
-    fn to_open_ai_request(&self, request: LanguageModelRequest) -> Request {
-        let model = match request.model {
-            LanguageModel::OpenAi(model) => model,
-            _ => self.model(),
-        };
-
-        Request {
-            model,
-            messages: request
-                .messages
-                .into_iter()
-                .map(|msg| match msg.role {
-                    Role::User => RequestMessage::User {
-                        content: msg.content,
-                    },
-                    Role::Assistant => RequestMessage::Assistant {
-                        content: Some(msg.content),
-                        tool_calls: Vec::new(),
-                    },
-                    Role::System => RequestMessage::System {
-                        content: msg.content,
-                    },
-                })
-                .collect(),
-            stream: true,
-            stop: request.stop,
-            temperature: request.temperature,
-            tools: Vec::new(),
-            tool_choice: None,
-        }
+    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
+        self
     }
 }
 
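Note: the new `as_any_mut` above is what makes the `update_current_as` calls seen throughout this diff possible: the global now holds a boxed trait object, and the helper downcasts back to the concrete provider. A minimal sketch of how such a helper can work; the types here are stand-ins, and only the `as_any_mut` signature and the `update_current_as` call shape come from this diff:

```rust
use std::any::Any;

// Stand-ins for the real provider types; the downcast mechanics are the point.
trait Provider {
    fn as_any_mut(&mut self) -> &mut dyn Any;
}

struct OpenAi {
    api_key: Option<String>,
}

impl Provider for OpenAi {
    fn as_any_mut(&mut self) -> &mut dyn Any {
        self
    }
}

struct Global {
    current: Box<dyn Provider>,
}

impl Global {
    // Run `update` only if the current provider is concretely a `T`.
    fn update_current_as<R, T: 'static>(
        &mut self,
        update: impl FnOnce(&mut T) -> R,
    ) -> Option<R> {
        self.current.as_any_mut().downcast_mut::<T>().map(update)
    }
}

fn main() {
    let mut global = Global {
        current: Box::new(OpenAi { api_key: None }),
    };
    global.update_current_as::<_, OpenAi>(|p| p.api_key = Some("key".into()));
}
```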
@@ -284,9 +295,9 @@ impl AuthenticationPrompt {
         cx.spawn(|_, mut cx| async move {
             write_credentials.await?;
             cx.update_global::<CompletionProvider, _>(|provider, _cx| {
-                if let CompletionProvider::OpenAi(provider) = provider {
+                provider.update_current_as::<_, OpenAiCompletionProvider>(|provider| {
                     provider.api_key = Some(api_key);
-                }
+                });
             })
         })
         .detach_and_log_err(cx);
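Note: the `Codegen` hunks below await the result of `complete(prompt, cx)` twice: once for the request to be admitted, then on `.inner` for the actual chunk stream. That double await is how the concurrency cap from this PR surfaces at call sites. A minimal sketch of the idea, assuming a semaphore-based limiter; `Response`, `inner`, and the permit handling here are illustrative, not the commit's actual implementation:

```rust
use std::sync::Arc;

use anyhow::Result;
use async_lock::{Semaphore, SemaphoreGuardArc};
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};

// The first await (on `complete`'s return value) waits for a permit;
// the second (`inner`) yields the chunk stream.
struct Response {
    inner: BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>,
    _permit: SemaphoreGuardArc, // held while the request is in flight
}

async fn complete(limiter: Arc<Semaphore>, prompt: String) -> Response {
    let permit = limiter.acquire_arc().await; // blocks when at the cap
    Response {
        _permit: permit,
        inner: async move {
            // A canned stream stands in for the provider's HTTP response.
            Ok(futures::stream::iter([Ok(prompt)]).boxed())
        }
        .boxed(),
    }
}

// Call sites then mirror the hunks below:
//     let response = complete(limiter, prompt).await;
//     let mut chunks = response.inner.await?;
```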
@@ -1986,13 +1986,14 @@ impl Codegen {
             .unwrap_or_else(|| snapshot.indent_size_for_line(MultiBufferRow(selection_start.row)));
 
         let model_telemetry_id = prompt.model.telemetry_id();
-        let response = CompletionProvider::global(cx).complete(prompt);
+        let response = CompletionProvider::global(cx).complete(prompt, cx);
         let telemetry = self.telemetry.clone();
         self.edit_position = range.start;
         self.diff = Diff::default();
         self.status = CodegenStatus::Pending;
         self.generation = cx.spawn(|this, mut cx| {
             async move {
+                let response = response.await;
                 let generate = async {
                     let mut edit_start = range.start.to_offset(&snapshot);
 
@@ -2002,7 +2003,7 @@ impl Codegen {
                 let mut response_latency = None;
                 let request_start = Instant::now();
                 let diff = async {
-                    let chunks = StripInvalidSpans::new(response.await?);
+                    let chunks = StripInvalidSpans::new(response.inner.await?);
                     futures::pin_mut!(chunks);
                     let mut diff = StreamingDiff::new(selected_text.to_string());
 
@@ -2473,9 +2474,8 @@ mod tests {
 
     #[gpui::test(iterations = 10)]
    async fn test_transform_autoindent(cx: &mut TestAppContext, mut rng: StdRng) {
-        let provider = FakeCompletionProvider::default();
         cx.set_global(cx.update(SettingsStore::test));
-        cx.set_global(CompletionProvider::Fake(provider.clone()));
+        let provider = cx.update(|cx| FakeCompletionProvider::setup_test(cx));
         cx.update(language_settings::init);
 
         let text = indoc! {"
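Note: the tests in these hunks now call `FakeCompletionProvider::setup_test(cx)` instead of constructing the fake and setting the global themselves. Presumably the helper does both and returns a handle for the test to drive; a sketch under that assumption, where `CompletionProvider::new` is a placeholder since the real registration API isn't shown in this excerpt:

```rust
// Hypothetical shape of the new test helper, inferred from these call sites:
// build the fake, install it as the global completion provider, return it.
impl FakeCompletionProvider {
    pub fn setup_test(cx: &mut AppContext) -> Self {
        let provider = Self::default();
        cx.set_global(CompletionProvider::new(Box::new(provider.clone())));
        provider
    }
}
```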
@@ -2495,8 +2495,11 @@ mod tests {
         });
         let codegen = cx.new_model(|cx| Codegen::new(buffer.clone(), range, None, cx));
 
-        let request = LanguageModelRequest::default();
-        codegen.update(cx, |codegen, cx| codegen.start(request, cx));
+        codegen.update(cx, |codegen, cx| {
+            codegen.start(LanguageModelRequest::default(), cx)
+        });
+
+        cx.background_executor.run_until_parked();
 
         let mut new_text = concat!(
             "    let mut x = 0;\n",
@@ -2508,11 +2511,11 @@ mod tests {
             let max_len = cmp::min(new_text.len(), 10);
             let len = rng.gen_range(1..=max_len);
             let (chunk, suffix) = new_text.split_at(len);
-            provider.send_completion(chunk.into());
+            provider.send_completion(&LanguageModelRequest::default(), chunk.into());
             new_text = suffix;
             cx.background_executor.run_until_parked();
         }
-        provider.finish_completion();
+        provider.finish_completion(&LanguageModelRequest::default());
         cx.background_executor.run_until_parked();
 
         assert_eq!(
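Note: `send_completion` and `finish_completion` now take the originating request, letting the fake provider track several in-flight completions independently. A self-contained sketch of the bookkeeping that implies, using `String` keys in place of `LanguageModelRequest`; the real fake may key and synchronize differently:

```rust
use std::collections::HashMap;
use std::sync::{Arc, Mutex};

use futures::channel::mpsc;

// Fake provider that keys pending completion channels by the originating
// request, so tests can drive several completions at once.
#[derive(Clone, Default)]
struct FakeCompletions {
    channels: Arc<Mutex<HashMap<String, mpsc::UnboundedSender<String>>>>,
}

impl FakeCompletions {
    fn start(&self, request: &str) -> mpsc::UnboundedReceiver<String> {
        let (tx, rx) = mpsc::unbounded();
        self.channels.lock().unwrap().insert(request.to_string(), tx);
        rx
    }

    fn send_completion(&self, request: &str, chunk: String) {
        if let Some(tx) = self.channels.lock().unwrap().get(request) {
            tx.unbounded_send(chunk).ok();
        }
    }

    fn finish_completion(&self, request: &str) {
        // Dropping the sender closes the stream, ending the completion.
        self.channels.lock().unwrap().remove(request);
    }
}

fn main() {
    let fake = FakeCompletions::default();
    let mut rx = fake.start("request-1");
    fake.send_completion("request-1", "chunk".into());
    fake.finish_completion("request-1");
    assert_eq!(rx.try_next().ok().flatten(), Some("chunk".into()));
}
```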
@@ -2533,8 +2536,7 @@ mod tests {
         cx: &mut TestAppContext,
         mut rng: StdRng,
     ) {
-        let provider = FakeCompletionProvider::default();
-        cx.set_global(CompletionProvider::Fake(provider.clone()));
+        let provider = cx.update(|cx| FakeCompletionProvider::setup_test(cx));
         cx.set_global(cx.update(SettingsStore::test));
         cx.update(language_settings::init);
 
@@ -2555,6 +2557,8 @@ mod tests {
         let request = LanguageModelRequest::default();
         codegen.update(cx, |codegen, cx| codegen.start(request, cx));
 
+        cx.background_executor.run_until_parked();
+
         let mut new_text = concat!(
             "t mut x = 0;\n",
             "while x < 10 {\n",
@@ -2565,11 +2569,11 @@ mod tests {
             let max_len = cmp::min(new_text.len(), 10);
             let len = rng.gen_range(1..=max_len);
             let (chunk, suffix) = new_text.split_at(len);
-            provider.send_completion(chunk.into());
+            provider.send_completion(&LanguageModelRequest::default(), chunk.into());
             new_text = suffix;
             cx.background_executor.run_until_parked();
         }
-        provider.finish_completion();
+        provider.finish_completion(&LanguageModelRequest::default());
         cx.background_executor.run_until_parked();
 
         assert_eq!(
@@ -2590,8 +2594,7 @@ mod tests {
         cx: &mut TestAppContext,
         mut rng: StdRng,
     ) {
-        let provider = FakeCompletionProvider::default();
-        cx.set_global(CompletionProvider::Fake(provider.clone()));
+        let provider = cx.update(|cx| FakeCompletionProvider::setup_test(cx));
         cx.set_global(cx.update(SettingsStore::test));
         cx.update(language_settings::init);
 
@@ -2612,6 +2615,8 @@ mod tests {
         let request = LanguageModelRequest::default();
         codegen.update(cx, |codegen, cx| codegen.start(request, cx));
 
+        cx.background_executor.run_until_parked();
+
         let mut new_text = concat!(
             "let mut x = 0;\n",
             "while x < 10 {\n",
@@ -2622,11 +2627,11 @@ mod tests {
             let max_len = cmp::min(new_text.len(), 10);
             let len = rng.gen_range(1..=max_len);
             let (chunk, suffix) = new_text.split_at(len);
-            provider.send_completion(chunk.into());
+            provider.send_completion(&LanguageModelRequest::default(), chunk.into());
             new_text = suffix;
             cx.background_executor.run_until_parked();
         }
-        provider.finish_completion();
+        provider.finish_completion(&LanguageModelRequest::default());
         cx.background_executor.run_until_parked();
 
         assert_eq!(
@@ -1026,9 +1026,10 @@ impl Codegen {
 
         let telemetry = self.telemetry.clone();
         let model_telemetry_id = prompt.model.telemetry_id();
-        let response = CompletionProvider::global(cx).complete(prompt);
+        let response = CompletionProvider::global(cx).complete(prompt, cx);
 
         self.generation = cx.spawn(|this, mut cx| async move {
+            let response = response.await;
             let generate = async {
                 let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1);
 
@@ -1036,7 +1037,7 @@ impl Codegen {
            let mut response_latency = None;
            let request_start = Instant::now();
            let task = async {
-                let mut response = response.await?;
+                let mut response = response.inner.await?;
                while let Some(chunk) = response.next().await {
                    if response_latency.is_none() {
                        response_latency = Some(request_start.elapsed());
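Note: the `response_latency` bookkeeping in this hunk measures time to first streamed chunk and records it exactly once. The pattern in isolation, as a standalone sketch where synchronous iteration stands in for the real async stream:

```rust
use std::time::{Duration, Instant};

// Record latency once, on the first chunk only; later chunks leave it alone.
fn first_chunk_latency(chunks: impl IntoIterator<Item = String>) -> Option<Duration> {
    let request_start = Instant::now();
    let mut response_latency = None;
    for _chunk in chunks {
        if response_latency.is_none() {
            response_latency = Some(request_start.elapsed());
        }
    }
    response_latency
}

fn main() {
    let latency = first_chunk_latency(vec!["hello".into(), "world".into()]);
    assert!(latency.is_some());
}
```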