Allow customization of the model used for tool calling (#15479)

We also eliminate the `completion` crate and moved its logic into
`LanguageModelRegistry`.

Release Notes:

- N/A

---------

Co-authored-by: Nathan <nathan@zed.dev>
This commit is contained in:
Antonio Scandurra 2024-07-30 16:18:53 +02:00 committed by GitHub
parent 1bfea9d443
commit 99bc90a372
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
32 changed files with 478 additions and 691 deletions

View file

@ -12,7 +12,7 @@ use ui::{prelude::*, ButtonLike, ElevationIndex};
use crate::{
settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderState, LanguageModelRequest, Role,
LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
};
const OLLAMA_DOWNLOAD_URL: &str = "https://ollama.com/download";
@ -39,7 +39,7 @@ struct State {
}
impl State {
fn fetch_models(&self, cx: &ModelContext<Self>) -> Task<Result<()>> {
fn fetch_models(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
let settings = &AllLanguageModelSettings::get_global(cx).ollama;
let http_client = self.http_client.clone();
let api_url = settings.api_url.clone();
@ -80,37 +80,10 @@ impl OllamaLanguageModelProvider {
}),
}),
};
this.fetch_models(cx).detach();
this.state
.update(cx, |state, cx| state.fetch_models(cx).detach());
this
}
fn fetch_models(&self, cx: &AppContext) -> Task<Result<()>> {
let settings = &AllLanguageModelSettings::get_global(cx).ollama;
let http_client = self.http_client.clone();
let api_url = settings.api_url.clone();
let state = self.state.clone();
// As a proxy for the server being "authenticated", we'll check if its up by fetching the models
cx.spawn(|mut cx| async move {
let models = get_models(http_client.as_ref(), &api_url, None).await?;
let mut models: Vec<ollama::Model> = models
.into_iter()
// Since there is no metadata from the Ollama API
// indicating which models are embedding models,
// simply filter out models with "-embed" in their name
.filter(|model| !model.name.contains("-embed"))
.map(|model| ollama::Model::new(&model.name))
.collect();
models.sort_by(|a, b| a.name.cmp(&b.name));
state.update(&mut cx, |this, cx| {
this.available_models = models;
cx.notify();
})
})
}
}
impl LanguageModelProviderState for OllamaLanguageModelProvider {
@ -140,6 +113,7 @@ impl LanguageModelProvider for OllamaLanguageModelProvider {
id: LanguageModelId::from(model.name.clone()),
model: model.clone(),
http_client: self.http_client.clone(),
request_limiter: RateLimiter::new(4),
}) as Arc<dyn LanguageModel>
})
.collect()
@ -158,11 +132,11 @@ impl LanguageModelProvider for OllamaLanguageModelProvider {
!self.state.read(cx).available_models.is_empty()
}
fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>> {
if self.is_authenticated(cx) {
Task::ready(Ok(()))
} else {
self.fetch_models(cx)
self.state.update(cx, |state, cx| state.fetch_models(cx))
}
}
@ -176,8 +150,8 @@ impl LanguageModelProvider for OllamaLanguageModelProvider {
.into()
}
fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
self.fetch_models(cx)
fn reset_credentials(&self, cx: &mut AppContext) -> Task<Result<()>> {
self.state.update(cx, |state, cx| state.fetch_models(cx))
}
}
@ -185,6 +159,7 @@ pub struct OllamaLanguageModel {
id: LanguageModelId,
model: ollama::Model,
http_client: Arc<dyn HttpClient>,
request_limiter: RateLimiter,
}
impl OllamaLanguageModel {
@ -235,14 +210,14 @@ impl LanguageModel for OllamaLanguageModel {
LanguageModelProviderName(PROVIDER_NAME.into())
}
fn max_token_count(&self) -> usize {
self.model.max_token_count()
}
fn telemetry_id(&self) -> String {
format!("ollama/{}", self.model.id())
}
fn max_token_count(&self) -> usize {
self.model.max_token_count()
}
fn count_tokens(
&self,
request: LanguageModelRequest,
@ -275,10 +250,10 @@ impl LanguageModel for OllamaLanguageModel {
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
};
async move {
let request =
stream_chat_completion(http_client.as_ref(), &api_url, request, low_speed_timeout);
let response = request.await?;
let future = self.request_limiter.stream(async move {
let response =
stream_chat_completion(http_client.as_ref(), &api_url, request, low_speed_timeout)
.await?;
let stream = response
.filter_map(|response| async move {
match response {
@ -295,11 +270,12 @@ impl LanguageModel for OllamaLanguageModel {
})
.boxed();
Ok(stream)
}
.boxed()
});
async move { Ok(future.await?.boxed()) }.boxed()
}
fn use_tool(
fn use_any_tool(
&self,
_request: LanguageModelRequest,
_name: String,