Allow customization of the model used for tool calling (#15479)

We also eliminated the `completion` crate and moved its logic into
`LanguageModelRegistry`.
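
As a rough illustration of the direction (not the actual `LanguageModelRegistry` API — the type and method names below are hypothetical), a registry-based approach lets callers ask for a model dedicated to tool calling and fall back to the active model otherwise:

```rust
use std::sync::Arc;

// Illustrative trait standing in for the real `LanguageModel` trait.
trait LanguageModel: Send + Sync {
    fn name(&self) -> &str;
}

// Hypothetical registry shape: it tracks the model used for ordinary
// completions plus an optional override for tool calls.
#[derive(Default)]
struct ToolModelRegistry {
    active_model: Option<Arc<dyn LanguageModel>>,
    tool_calling_model: Option<Arc<dyn LanguageModel>>,
}

impl ToolModelRegistry {
    // Prefer the explicitly configured tool-calling model; otherwise use
    // whatever model is active for regular completions.
    fn model_for_tool_calls(&self) -> Option<Arc<dyn LanguageModel>> {
        self.tool_calling_model
            .clone()
            .or_else(|| self.active_model.clone())
    }
}
```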

Release Notes:

- N/A

---------

Co-authored-by: Nathan <nathan@zed.dev>
Antonio Scandurra authored on 2024-07-30 16:18:53 +02:00, committed by GitHub
parent 1bfea9d443
commit 99bc90a372
32 changed files with 478 additions and 691 deletions


@@ -27,7 +27,7 @@ use crate::settings::AllLanguageModelSettings;
 use crate::LanguageModelProviderState;
 use crate::{
     LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
-    LanguageModelProviderId, LanguageModelProviderName, LanguageModelRequest, Role,
+    LanguageModelProviderId, LanguageModelProviderName, LanguageModelRequest, RateLimiter, Role,
 };
 
 use super::open_ai::count_open_ai_tokens;
@@ -85,7 +85,12 @@ impl LanguageModelProvider for CopilotChatLanguageModelProvider {
     fn provided_models(&self, _cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
         CopilotChatModel::iter()
-            .map(|model| Arc::new(CopilotChatLanguageModel { model }) as Arc<dyn LanguageModel>)
+            .map(|model| {
+                Arc::new(CopilotChatLanguageModel {
+                    model,
+                    request_limiter: RateLimiter::new(4),
+                }) as Arc<dyn LanguageModel>
+            })
             .collect()
     }
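
The `RateLimiter::new(4)` above gives each provided model its own cap of four concurrent requests. As a mental model only — the actual `RateLimiter` in the `language_model` crate may be implemented differently — a semaphore-based limiter could look roughly like this (`RequestLimiter` is a hypothetical stand-in):

```rust
use std::future::Future;
use std::sync::Arc;
use tokio::sync::Semaphore;

// Hypothetical stand-in for the limiter handle each model now owns.
#[derive(Clone)]
struct RequestLimiter {
    semaphore: Arc<Semaphore>,
}

impl RequestLimiter {
    // `RequestLimiter::new(4)` allows at most four requests in flight at once.
    fn new(max_concurrent: usize) -> Self {
        Self {
            semaphore: Arc::new(Semaphore::new(max_concurrent)),
        }
    }

    // Wait for a free slot, run the request, and release the slot when the
    // future completes.
    async fn run<T>(&self, request: impl Future<Output = T>) -> T {
        let _permit = self
            .semaphore
            .acquire()
            .await
            .expect("semaphore is never closed");
        request.await
    }
}
```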
@@ -95,7 +100,7 @@ impl LanguageModelProvider for CopilotChatLanguageModelProvider {
             .unwrap_or(false)
     }
 
-    fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
+    fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>> {
         let result = if self.is_authenticated(cx) {
             Ok(())
         } else if let Some(copilot) = Copilot::global(cx) {
@@ -121,7 +126,7 @@ impl LanguageModelProvider for CopilotChatLanguageModelProvider {
         cx.new_view(|cx| AuthenticationPrompt::new(cx)).into()
     }
 
-    fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
+    fn reset_credentials(&self, cx: &mut AppContext) -> Task<Result<()>> {
         let Some(copilot) = Copilot::global(cx) else {
             return Task::ready(Err(anyhow::anyhow!(
                 "Copilot is not available. Please ensure Copilot is enabled and running and try again."
@@ -145,6 +150,7 @@ impl LanguageModelProvider for CopilotChatLanguageModelProvider {
 pub struct CopilotChatLanguageModel {
     model: CopilotChatModel,
+    request_limiter: RateLimiter,
 }
 
 impl LanguageModel for CopilotChatLanguageModel {
@@ -215,30 +221,35 @@ impl LanguageModel for CopilotChatLanguageModel {
             return futures::future::ready(Err(anyhow::anyhow!("App state dropped"))).boxed();
         };
 
-        cx.spawn(|mut cx| async move {
-            let response = CopilotChat::stream_completion(request, low_speed_timeout, &mut cx).await?;
-            let stream = response
-                .filter_map(|response| async move {
-                    match response {
-                        Ok(result) => {
-                            let choice = result.choices.first();
-                            match choice {
-                                Some(choice) => Some(Ok(choice.delta.content.clone().unwrap_or_default())),
-                                None => Some(Err(anyhow::anyhow!(
-                                    "The Copilot Chat API returned a response with no choices, but hadn't finished the message yet. Please try again."
-                                ))),
-                            }
-                        }
-                        Err(err) => Some(Err(err)),
-                    }
-                })
-                .boxed();
-            Ok(stream)
-        })
-        .boxed()
+        let request_limiter = self.request_limiter.clone();
+        let future = cx.spawn(|cx| async move {
+            let response = CopilotChat::stream_completion(request, low_speed_timeout, cx);
+            request_limiter.stream(async move {
+                let response = response.await?;
+                let stream = response
+                    .filter_map(|response| async move {
+                        match response {
+                            Ok(result) => {
+                                let choice = result.choices.first();
+                                match choice {
+                                    Some(choice) => Some(Ok(choice.delta.content.clone().unwrap_or_default())),
+                                    None => Some(Err(anyhow::anyhow!(
+                                        "The Copilot Chat API returned a response with no choices, but hadn't finished the message yet. Please try again."
+                                    ))),
+                                }
+                            }
+                            Err(err) => Some(Err(err)),
+                        }
+                    })
+                    .boxed();
+                Ok(stream)
+            }).await
+        });
+
+        async move { Ok(future.await?.boxed()) }.boxed()
     }
 
-    fn use_tool(
+    fn use_any_tool(
         &self,
         _request: LanguageModelRequest,
         _name: String,
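
The rewritten completion path routes the request through `request_limiter.stream(...)`, which suggests the limiter slot is meant to stay occupied while the response is still streaming, not just while the connection is being opened. A minimal sketch of that pattern, again using a hypothetical semaphore-based `RequestLimiter` rather than the crate's actual type:

```rust
use std::future::Future;
use std::sync::Arc;

use anyhow::Result;
use futures::{stream::BoxStream, Stream, StreamExt};
use tokio::sync::{OwnedSemaphorePermit, Semaphore};

#[derive(Clone)]
struct RequestLimiter {
    semaphore: Arc<Semaphore>,
}

impl RequestLimiter {
    fn new(max_concurrent: usize) -> Self {
        Self {
            semaphore: Arc::new(Semaphore::new(max_concurrent)),
        }
    }

    // Acquire a permit, await the future that opens the stream, then move the
    // permit into the stream so the slot is released only when the consumer
    // drops the stream (after the last token arrives, or on cancellation).
    async fn stream<F, S, T>(&self, open: F) -> Result<BoxStream<'static, T>>
    where
        F: Future<Output = Result<S>>,
        S: Stream<Item = T> + Send + 'static,
        T: Send + 'static,
    {
        let permit: OwnedSemaphorePermit = self.semaphore.clone().acquire_owned().await?;
        let inner = open.await?;
        Ok(inner
            .map(move |item| {
                // Keep the permit alive for as long as the stream is polled.
                let _ = &permit;
                item
            })
            .boxed())
    }
}
```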