Allow customization of the model used for tool calling (#15479)

We also eliminate the `completion` crate and moved its logic into
`LanguageModelRegistry`.

Release Notes:

- N/A

---------

Co-authored-by: Nathan <nathan@zed.dev>
This commit is contained in:
Antonio Scandurra 2024-07-30 16:18:53 +02:00 committed by GitHub
parent 1bfea9d443
commit 99bc90a372
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
32 changed files with 478 additions and 691 deletions

View file

@ -20,7 +20,7 @@ use util::ResultExt;
use crate::{
settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderState, LanguageModelRequest,
LanguageModelProviderState, LanguageModelRequest, RateLimiter,
};
const PROVIDER_ID: &str = "google";
@ -111,6 +111,7 @@ impl LanguageModelProvider for GoogleLanguageModelProvider {
model,
state: self.state.clone(),
http_client: self.http_client.clone(),
rate_limiter: RateLimiter::new(4),
}) as Arc<dyn LanguageModel>
})
.collect()
@ -120,7 +121,7 @@ impl LanguageModelProvider for GoogleLanguageModelProvider {
self.state.read(cx).api_key.is_some()
}
fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>> {
if self.is_authenticated(cx) {
Task::ready(Ok(()))
} else {
@ -153,7 +154,7 @@ impl LanguageModelProvider for GoogleLanguageModelProvider {
.into()
}
fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
fn reset_credentials(&self, cx: &mut AppContext) -> Task<Result<()>> {
let state = self.state.clone();
let delete_credentials =
cx.delete_credentials(&AllLanguageModelSettings::get_global(cx).google.api_url);
@ -172,6 +173,7 @@ pub struct GoogleLanguageModel {
model: google_ai::Model,
state: gpui::Model<State>,
http_client: Arc<dyn HttpClient>,
rate_limiter: RateLimiter,
}
impl LanguageModel for GoogleLanguageModel {
@ -243,17 +245,17 @@ impl LanguageModel for GoogleLanguageModel {
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
};
async move {
let future = self.rate_limiter.stream(async move {
let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
let response =
stream_generate_content(http_client.as_ref(), &api_url, &api_key, request);
let events = response.await?;
Ok(google_ai::extract_text_from_events(events).boxed())
}
.boxed()
});
async move { Ok(future.await?.boxed()) }.boxed()
}
fn use_tool(
fn use_any_tool(
&self,
_request: LanguageModelRequest,
_name: String,