Allow customization of the model used for tool calling (#15479)
We also eliminated the `completion` crate and moved its logic into `LanguageModelRegistry`.

Release Notes:

- N/A

---------

Co-authored-by: Nathan <nathan@zed.dev>
commit 99bc90a372
parent 1bfea9d443

32 changed files with 478 additions and 691 deletions
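The headline change is letting the model used for tool calling be configured independently of the chat model. In registry terms, that means resolving a tool-call override before falling back to the active model. A hypothetical sketch of the idea (the names `tool_model` and `model_for_tool_calls` are illustrative, not Zed's actual API):

```rust
use std::sync::Arc;

// Illustrative stand-in for the real trait.
trait LanguageModel: Send + Sync {}

// Hypothetical sketch: a registry with a separately configurable slot for
// the model that handles tool calls.
struct LanguageModelRegistry {
    active_model: Option<Arc<dyn LanguageModel>>,
    tool_model: Option<Arc<dyn LanguageModel>>,
}

impl LanguageModelRegistry {
    // Model used for ordinary completions.
    fn active_model(&self) -> Option<Arc<dyn LanguageModel>> {
        self.active_model.clone()
    }

    // Model used for tool calls: the override when set, otherwise the
    // active chat model.
    fn model_for_tool_calls(&self) -> Option<Arc<dyn LanguageModel>> {
        self.tool_model.clone().or_else(|| self.active_model.clone())
    }
}
```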
Excerpt of the diff (Ollama provider):

```diff
@@ -12,7 +12,7 @@ use ui::{prelude::*, ButtonLike, ElevationIndex};
 use crate::{
     settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
     LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
-    LanguageModelProviderState, LanguageModelRequest, Role,
+    LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
 };
 
 const OLLAMA_DOWNLOAD_URL: &str = "https://ollama.com/download";
```
```diff
@@ -39,7 +39,7 @@ struct State {
 }
 
 impl State {
-    fn fetch_models(&self, cx: &ModelContext<Self>) -> Task<Result<()>> {
+    fn fetch_models(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
         let settings = &AllLanguageModelSettings::get_global(cx).ollama;
         let http_client = self.http_client.clone();
         let api_url = settings.api_url.clone();
```
```diff
@@ -80,37 +80,10 @@ impl OllamaLanguageModelProvider {
                 }),
             }),
         };
-        this.fetch_models(cx).detach();
+        this.state
+            .update(cx, |state, cx| state.fetch_models(cx).detach());
         this
     }
-
-    fn fetch_models(&self, cx: &AppContext) -> Task<Result<()>> {
-        let settings = &AllLanguageModelSettings::get_global(cx).ollama;
-        let http_client = self.http_client.clone();
-        let api_url = settings.api_url.clone();
-
-        let state = self.state.clone();
-        // As a proxy for the server being "authenticated", we'll check if its up by fetching the models
-        cx.spawn(|mut cx| async move {
-            let models = get_models(http_client.as_ref(), &api_url, None).await?;
-
-            let mut models: Vec<ollama::Model> = models
-                .into_iter()
-                // Since there is no metadata from the Ollama API
-                // indicating which models are embedding models,
-                // simply filter out models with "-embed" in their name
-                .filter(|model| !model.name.contains("-embed"))
-                .map(|model| ollama::Model::new(&model.name))
-                .collect();
-
-            models.sort_by(|a, b| a.name.cmp(&b.name));
-
-            state.update(&mut cx, |this, cx| {
-                this.available_models = models;
-                cx.notify();
-            })
-        })
-    }
 }
 
 impl LanguageModelProviderState for OllamaLanguageModelProvider {
```
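The provider-level `fetch_models` removed above is not gone: per the `@@ -39,7 +39,7 @@` hunk, it now lives on `State`. Assuming the body moved over largely unchanged, the relocated method plausibly reads as follows (a sketch relying on the file's existing imports; the `cx.spawn(|this, mut cx| ...)` shape follows GPUI's `ModelContext` API, where `this` is a weak handle back to the model):

```rust
impl State {
    fn fetch_models(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
        let settings = &AllLanguageModelSettings::get_global(cx).ollama;
        let http_client = self.http_client.clone();
        let api_url = settings.api_url.clone();

        // As a proxy for the server being "authenticated", check that it's up
        // by fetching the models it hosts.
        cx.spawn(|this, mut cx| async move {
            let models = get_models(http_client.as_ref(), &api_url, None).await?;

            let mut models: Vec<ollama::Model> = models
                .into_iter()
                // The Ollama API has no metadata marking embedding models,
                // so filter out models with "-embed" in their name.
                .filter(|model| !model.name.contains("-embed"))
                .map(|model| ollama::Model::new(&model.name))
                .collect();

            models.sort_by(|a, b| a.name.cmp(&b.name));

            // Publish the refreshed list and notify observers.
            this.update(&mut cx, |this, cx| {
                this.available_models = models;
                cx.notify();
            })
        })
    }
}
```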
```diff
@@ -140,6 +113,7 @@ impl LanguageModelProvider for OllamaLanguageModelProvider {
                     id: LanguageModelId::from(model.name.clone()),
                     model: model.clone(),
                     http_client: self.http_client.clone(),
+                    request_limiter: RateLimiter::new(4),
                 }) as Arc<dyn LanguageModel>
             })
             .collect()
```
```diff
@@ -158,11 +132,11 @@ impl LanguageModelProvider for OllamaLanguageModelProvider {
         !self.state.read(cx).available_models.is_empty()
     }
 
-    fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
+    fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>> {
         if self.is_authenticated(cx) {
             Task::ready(Ok(()))
         } else {
-            self.fetch_models(cx)
+            self.state.update(cx, |state, cx| state.fetch_models(cx))
         }
     }
 
```
```diff
@@ -176,8 +150,8 @@ impl LanguageModelProvider for OllamaLanguageModelProvider {
             .into()
     }
 
-    fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
-        self.fetch_models(cx)
+    fn reset_credentials(&self, cx: &mut AppContext) -> Task<Result<()>> {
+        self.state.update(cx, |state, cx| state.fetch_models(cx))
     }
 }
 
```
```diff
@@ -185,6 +159,7 @@ pub struct OllamaLanguageModel {
     id: LanguageModelId,
     model: ollama::Model,
     http_client: Arc<dyn HttpClient>,
+    request_limiter: RateLimiter,
 }
 
 impl OllamaLanguageModel {
```
```diff
@@ -235,14 +210,14 @@ impl LanguageModel for OllamaLanguageModel {
         LanguageModelProviderName(PROVIDER_NAME.into())
     }
 
-    fn max_token_count(&self) -> usize {
-        self.model.max_token_count()
-    }
-
     fn telemetry_id(&self) -> String {
         format!("ollama/{}", self.model.id())
     }
 
+    fn max_token_count(&self) -> usize {
+        self.model.max_token_count()
+    }
+
     fn count_tokens(
         &self,
         request: LanguageModelRequest,
```
```diff
@@ -275,10 +250,10 @@ impl LanguageModel for OllamaLanguageModel {
             return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
         };
 
-        async move {
-            let request =
-                stream_chat_completion(http_client.as_ref(), &api_url, request, low_speed_timeout);
-            let response = request.await?;
+        let future = self.request_limiter.stream(async move {
+            let response =
+                stream_chat_completion(http_client.as_ref(), &api_url, request, low_speed_timeout)
+                    .await?;
             let stream = response
                 .filter_map(|response| async move {
                     match response {
```
```diff
@@ -295,11 +270,12 @@ impl LanguageModel for OllamaLanguageModel {
                 })
                 .boxed();
             Ok(stream)
-        }
-        .boxed()
+        });
+
+        async move { Ok(future.await?.boxed()) }.boxed()
     }
 
-    fn use_tool(
+    fn use_any_tool(
         &self,
         _request: LanguageModelRequest,
         _name: String,
```
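`RateLimiter::new(4)` caps each Ollama model at four requests in flight, and `request_limiter.stream(...)` gates the completion future behind that cap. The limiter's implementation is not part of this excerpt; a minimal sketch with the same surface, assuming `smol`'s async semaphore:

```rust
use std::sync::Arc;

use anyhow::Result;
use futures::{future::BoxFuture, Future, FutureExt};
use smol::lock::Semaphore;

// A sketch, not Zed's implementation: bounds how many requests may be in
// flight at once, rather than limiting requests per second.
pub struct RateLimiter {
    semaphore: Arc<Semaphore>,
}

impl RateLimiter {
    pub fn new(limit: usize) -> Self {
        Self {
            semaphore: Arc::new(Semaphore::new(limit)),
        }
    }

    // Waits for a permit, runs `future`, and releases the permit when the
    // future completes. (The real limiter may hold its permit while the
    // returned stream is drained.)
    pub fn stream<T: Send + 'static>(
        &self,
        future: impl Future<Output = Result<T>> + Send + 'static,
    ) -> BoxFuture<'static, Result<T>> {
        let semaphore = self.semaphore.clone();
        async move {
            let _permit = semaphore.acquire_arc().await;
            future.await
        }
        .boxed()
    }
}
```

Bounding concurrency rather than request rate keeps a slow local Ollama server from being swamped while leaving fast responses unthrottled.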