assistant: Limit amount of concurrent completion requests (#13856)

This PR refactors the completion providers to process only a maximum
number of completion requests at a time.
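
The limiting mechanism itself is not visible in the excerpt below. As a rough sketch of the general technique (not this PR's actual code), a counting semaphore can cap the number of in-flight requests; the constant, the helper function, and the use of tokio here are all illustrative:

```rust
use std::sync::Arc;
use std::time::Duration;

use tokio::sync::Semaphore;

// Hypothetical cap; the actual limit used by the PR is not shown here.
const MAX_CONCURRENT_COMPLETION_REQUESTS: usize = 4;

// Stand-in for a real streaming completion request.
async fn fake_completion_request(id: usize) -> String {
    tokio::time::sleep(Duration::from_millis(100)).await;
    format!("response {id}")
}

#[tokio::main]
async fn main() {
    let limiter = Arc::new(Semaphore::new(MAX_CONCURRENT_COMPLETION_REQUESTS));
    let mut handles = Vec::new();

    for id in 0..16 {
        // Waits here whenever the maximum number of requests is in flight.
        let permit = limiter.clone().acquire_owned().await.unwrap();
        handles.push(tokio::spawn(async move {
            let response = fake_completion_request(id).await;
            // Dropping the permit admits the next waiting request.
            drop(permit);
            response
        }));
    }

    for handle in handles {
        println!("{}", handle.await.unwrap());
    }
}
```

Acquiring the permit before spawning means callers queue up naturally once the cap is reached, rather than failing outright.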

Also started refactoring the language model providers to use traits, so
it will be easier to support multiple providers in the future.
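
The diff below shows the first step of that refactor: `OpenAiCompletionProvider`'s inherent methods move into an `impl LanguageModelCompletionProvider for OpenAiCompletionProvider` block. Condensed to its essentials, the shape looks roughly like the following sketch; the real trait has more methods (`authenticate`, `count_tokens`, `settings_version`, ...) and async signatures, and `AnthropicProvider` and the method bodies are placeholders:

```rust
// Condensed sketch of a trait-based provider design.
trait LanguageModelCompletionProvider {
    fn name(&self) -> &'static str;
    fn is_authenticated(&self) -> bool;
    fn complete(&self, prompt: &str) -> String;
}

struct OpenAiProvider;
struct AnthropicProvider;

impl LanguageModelCompletionProvider for OpenAiProvider {
    fn name(&self) -> &'static str {
        "openai"
    }
    fn is_authenticated(&self) -> bool {
        true
    }
    fn complete(&self, prompt: &str) -> String {
        format!("[openai] {prompt}")
    }
}

impl LanguageModelCompletionProvider for AnthropicProvider {
    fn name(&self) -> &'static str {
        "anthropic"
    }
    fn is_authenticated(&self) -> bool {
        false
    }
    fn complete(&self, prompt: &str) -> String {
        format!("[anthropic] {prompt}")
    }
}

fn main() {
    // With a shared trait, the active provider becomes a trait object that
    // can be swapped at runtime instead of a hard-coded enum variant.
    let providers: Vec<Box<dyn LanguageModelCompletionProvider>> =
        vec![Box::new(OpenAiProvider), Box::new(AnthropicProvider)];
    for provider in providers {
        println!("{}: {}", provider.name(), provider.complete("hello"));
    }
}
```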

Release Notes:

- N/A
Bennet Bo Fenner 2024-07-05 14:52:45 +02:00 committed by GitHub
parent f2711b2fca
commit c4dbe32f20
11 changed files with 693 additions and 532 deletions

@@ -1,5 +1,6 @@
 use crate::assistant_settings::CloudModel;
 use crate::assistant_settings::{AssistantProvider, AssistantSettings};
+use crate::LanguageModelCompletionProvider;
 use crate::{
     assistant_settings::OpenAiModel, CompletionProvider, LanguageModel, LanguageModelRequest, Role,
 };
@@ -57,37 +58,75 @@ impl OpenAiCompletionProvider {
         self.settings_version = settings_version;
     }

-    pub fn available_models(&self, cx: &AppContext) -> impl Iterator<Item = OpenAiModel> {
+    fn to_open_ai_request(&self, request: LanguageModelRequest) -> Request {
+        let model = match request.model {
+            LanguageModel::OpenAi(model) => model,
+            _ => self.model.clone(),
+        };
+
+        Request {
+            model,
+            messages: request
+                .messages
+                .into_iter()
+                .map(|msg| match msg.role {
+                    Role::User => RequestMessage::User {
+                        content: msg.content,
+                    },
+                    Role::Assistant => RequestMessage::Assistant {
+                        content: Some(msg.content),
+                        tool_calls: Vec::new(),
+                    },
+                    Role::System => RequestMessage::System {
+                        content: msg.content,
+                    },
+                })
+                .collect(),
+            stream: true,
+            stop: request.stop,
+            temperature: request.temperature,
+            tools: Vec::new(),
+            tool_choice: None,
+        }
+    }
+}
+
+impl LanguageModelCompletionProvider for OpenAiCompletionProvider {
+    fn available_models(&self, cx: &AppContext) -> Vec<LanguageModel> {
         if let AssistantProvider::OpenAi {
             available_models, ..
         } = &AssistantSettings::get_global(cx).provider
         {
             if !available_models.is_empty() {
                 // available_models is set, just return it
-                return available_models.clone().into_iter();
+                return available_models
+                    .iter()
+                    .cloned()
+                    .map(LanguageModel::OpenAi)
+                    .collect();
             }
         }

         let available_models = if matches!(self.model, OpenAiModel::Custom { .. }) {
             // available_models is not set but the default model is set to custom, only show custom
             vec![self.model.clone()]
         } else {
             // default case, use all models except custom
             OpenAiModel::iter()
                 .filter(|model| !matches!(model, OpenAiModel::Custom { .. }))
                 .collect()
         };
-        available_models.into_iter()
+
+        available_models
+            .into_iter()
+            .map(LanguageModel::OpenAi)
+            .collect()
     }

-    pub fn settings_version(&self) -> usize {
+    fn settings_version(&self) -> usize {
         self.settings_version
     }

-    pub fn is_authenticated(&self) -> bool {
+    fn is_authenticated(&self) -> bool {
         self.api_key.is_some()
     }

-    pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
+    fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
         if self.is_authenticated() {
             Task::ready(Ok(()))
         } else {
@@ -103,36 +142,36 @@ impl OpenAiCompletionProvider {
                     String::from_utf8(api_key)?
                 };
                 cx.update_global::<CompletionProvider, _>(|provider, _cx| {
-                    if let CompletionProvider::OpenAi(provider) = provider {
+                    provider.update_current_as::<_, Self>(|provider| {
                         provider.api_key = Some(api_key);
-                    }
+                    });
                 })
             })
         }
     }

-    pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
+    fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
         let delete_credentials = cx.delete_credentials(&self.api_url);
         cx.spawn(|mut cx| async move {
             delete_credentials.await.log_err();
             cx.update_global::<CompletionProvider, _>(|provider, _cx| {
-                if let CompletionProvider::OpenAi(provider) = provider {
+                provider.update_current_as::<_, Self>(|provider| {
                     provider.api_key = None;
-                }
+                });
             })
         })
     }

-    pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
+    fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
         cx.new_view(|cx| AuthenticationPrompt::new(self.api_url.clone(), cx))
             .into()
     }

-    pub fn model(&self) -> OpenAiModel {
-        self.model.clone()
+    fn model(&self) -> LanguageModel {
+        LanguageModel::OpenAi(self.model.clone())
     }

-    pub fn count_tokens(
+    fn count_tokens(
         &self,
         request: LanguageModelRequest,
         cx: &AppContext,
@@ -140,7 +179,7 @@ impl OpenAiCompletionProvider {
         count_open_ai_tokens(request, cx.background_executor())
     }

-    pub fn complete(
+    fn complete(
         &self,
         request: LanguageModelRequest,
     ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
@@ -173,36 +212,8 @@ impl OpenAiCompletionProvider {
         .boxed()
     }

-    fn to_open_ai_request(&self, request: LanguageModelRequest) -> Request {
-        let model = match request.model {
-            LanguageModel::OpenAi(model) => model,
-            _ => self.model(),
-        };
-
-        Request {
-            model,
-            messages: request
-                .messages
-                .into_iter()
-                .map(|msg| match msg.role {
-                    Role::User => RequestMessage::User {
-                        content: msg.content,
-                    },
-                    Role::Assistant => RequestMessage::Assistant {
-                        content: Some(msg.content),
-                        tool_calls: Vec::new(),
-                    },
-                    Role::System => RequestMessage::System {
-                        content: msg.content,
-                    },
-                })
-                .collect(),
-            stream: true,
-            stop: request.stop,
-            temperature: request.temperature,
-            tools: Vec::new(),
-            tool_choice: None,
-        }
-    }
+    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
+        self
+    }
 }
@@ -284,9 +295,9 @@ impl AuthenticationPrompt {
         cx.spawn(|_, mut cx| async move {
             write_credentials.await?;
             cx.update_global::<CompletionProvider, _>(|provider, _cx| {
-                if let CompletionProvider::OpenAi(provider) = provider {
+                provider.update_current_as::<_, OpenAiCompletionProvider>(|provider| {
                     provider.api_key = Some(api_key);
-                }
+                });
             })
         })
         .detach_and_log_err(cx);
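
One pattern worth calling out: the `if let CompletionProvider::OpenAi(provider) = provider` matches are replaced throughout by `provider.update_current_as::<_, Self>(...)`, backed by the new `as_any_mut` trait method at the end of the diff. A helper like that is typically a thin wrapper over `Any` downcasting; the following is a reconstruction for illustration, not this PR's exact code:

```rust
use std::any::Any;

// Minimal stand-in for the provider trait; only the downcast hook is shown.
trait LanguageModelCompletionProvider {
    fn as_any_mut(&mut self) -> &mut dyn Any;
}

struct CompletionProvider {
    provider: Box<dyn LanguageModelCompletionProvider>,
}

impl CompletionProvider {
    // Runs `update` only when the active provider has concrete type `T`,
    // replacing per-variant `if let CompletionProvider::OpenAi(..)` matches
    // with a single generic downcast.
    fn update_current_as<R, T: 'static>(
        &mut self,
        update: impl FnOnce(&mut T) -> R,
    ) -> Option<R> {
        self.provider.as_any_mut().downcast_mut::<T>().map(update)
    }
}

struct OpenAiCompletionProvider {
    api_key: Option<String>,
}

impl LanguageModelCompletionProvider for OpenAiCompletionProvider {
    fn as_any_mut(&mut self) -> &mut dyn Any {
        self
    }
}

fn main() {
    let mut global = CompletionProvider {
        provider: Box::new(OpenAiCompletionProvider { api_key: None }),
    };
    // Mirrors the call sites in the diff above.
    global.update_current_as::<_, OpenAiCompletionProvider>(|provider| {
        provider.api_key = Some("sk-example".into());
    });
}
```

The payoff is that call sites such as the credential updates above no longer need to enumerate provider variants, so new providers can be added without touching them.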