Ollama improvements (#12921)

Attempt to load the model early on when the user has switched the model.

This is a follow up to #12902

Release Notes:

- N/A
This commit is contained in:
Kyle Kelley 2024-06-12 08:10:51 -07:00 committed by GitHub
parent 113546f766
commit bee3441c78
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 67 additions and 7 deletions

View file

@ -62,6 +62,7 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
client.http_client(),
low_speed_timeout_in_seconds.map(Duration::from_secs),
settings_version,
cx,
)),
};
cx.set_global(provider);
@ -114,6 +115,7 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
api_url.clone(),
low_speed_timeout_in_seconds.map(Duration::from_secs),
settings_version,
cx,
);
}
@ -174,6 +176,7 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
client.http_client(),
low_speed_timeout_in_seconds.map(Duration::from_secs),
settings_version,
cx,
));
}
}

View file

@ -7,7 +7,8 @@ use futures::{future::BoxFuture, stream::BoxStream, FutureExt};
use gpui::{AnyView, AppContext, Task};
use http::HttpClient;
use ollama::{
get_models, stream_chat_completion, ChatMessage, ChatOptions, ChatRequest, Role as OllamaRole,
get_models, preload_model, stream_chat_completion, ChatMessage, ChatOptions, ChatRequest,
Role as OllamaRole,
};
use std::sync::Arc;
use std::time::Duration;
@ -31,7 +32,17 @@ impl OllamaCompletionProvider {
http_client: Arc<dyn HttpClient>,
low_speed_timeout: Option<Duration>,
settings_version: usize,
cx: &AppContext,
) -> Self {
cx.spawn({
let api_url = api_url.clone();
let client = http_client.clone();
let model = model.name.clone();
|_| async move { preload_model(client.as_ref(), &api_url, &model).await }
})
.detach_and_log_err(cx);
Self {
api_url,
model,
@ -48,7 +59,17 @@ impl OllamaCompletionProvider {
api_url: String,
low_speed_timeout: Option<Duration>,
settings_version: usize,
cx: &AppContext,
) {
cx.spawn({
let api_url = api_url.clone();
let client = self.http_client.clone();
let model = model.name.clone();
|_| async move { preload_model(client.as_ref(), &api_url, &model).await }
})
.detach_and_log_err(cx);
self.model = model;
self.api_url = api_url;
self.low_speed_timeout = low_speed_timeout;
@ -93,7 +114,7 @@ impl OllamaCompletionProvider {
// indicating which models are embedding models,
// simply filter out models with "-embed" in their name
.filter(|model| !model.name.contains("-embed"))
.map(|model| OllamaModel::new(&model.name, &model.details.parameter_size))
.map(|model| OllamaModel::new(&model.name))
.collect();
models.sort_by(|a, b| a.name.cmp(&b.name));