Merge models in local settings with ones listed by ollama

This supports the scenario where the user doesn't have access to Ollama's model-listing endpoint and instead needs to declare the available models to Zed explicitly, by hand.
This commit is contained in:
Oliver Azevedo Barnes 2025-07-31 13:42:53 +01:00
parent 1060d1b301
commit 947781bc48
No known key found for this signature in database
4 changed files with 224 additions and 13 deletions

1
Cargo.lock generated
View file

@ -10945,6 +10945,7 @@ name = "ollama"
version = "0.1.0"
dependencies = [
"anyhow",
"client",
"editor",
"futures 0.3.31",
"gpui",

View file

@ -27,6 +27,7 @@ gpui.workspace = true
http_client.workspace = true
inline_completion.workspace = true
language.workspace = true
log.workspace = true
project.workspace = true
@ -38,6 +39,7 @@ text.workspace = true
workspace-hack.workspace = true
[dev-dependencies]
client = { workspace = true, features = ["test-support"] }
editor = { workspace = true, features = ["test-support"] }
gpui = { workspace = true, features = ["test-support"] }
http_client = { workspace = true, features = ["test-support"] }

View file

@ -13,6 +13,30 @@ use project::Project;
pub const OLLAMA_DEBOUNCE_TIMEOUT: Duration = Duration::from_millis(75);
/// Model data passed from user settings into the Ollama service.
///
/// Mirrors the settings-side model definition without depending on the
/// settings crate directly, which would create a circular dependency.
#[derive(Clone, Debug, PartialEq)]
pub struct SettingsModel {
    /// Model identifier as known to Ollama (e.g. "llama3:latest").
    pub name: String,
    /// Optional human-friendly name to show instead of `name`.
    pub display_name: Option<String>,
    /// Token limit forwarded to the runtime model definition.
    pub max_tokens: u64,
    /// Whether the model supports tool calling, if known.
    pub supports_tools: Option<bool>,
    /// Whether the model accepts image input, if known.
    pub supports_images: Option<bool>,
    /// Whether the model supports "thinking" output, if known.
    pub supports_thinking: Option<bool>,
}
impl SettingsModel {
pub fn to_model(&self) -> Model {
Model::new(
&self.name,
self.display_name.as_deref(),
Some(self.max_tokens),
self.supports_tools,
self.supports_images,
self.supports_thinking,
)
}
}
// Global Ollama service for managing models across all providers
pub struct OllamaService {
http_client: Arc<dyn HttpClient>,
@ -61,6 +85,19 @@ impl OllamaService {
self.restart_fetch_models_task(cx);
}
/// Replace the available models with the user-configured list, then restart
/// the background fetch so API-discovered models get merged on top of it.
pub fn set_settings_models(
    &mut self,
    settings_models: Vec<SettingsModel>,
    cx: &mut Context<Self>,
) {
    // Convert each settings entry into the runtime model type.
    let mut converted = Vec::with_capacity(settings_models.len());
    for settings_model in &settings_models {
        converted.push(settings_model.to_model());
    }
    self.available_models = converted;
    self.restart_fetch_models_task(cx);
}
// Start a fresh fetch of models from the Ollama API. Overwriting the stored
// task drops any in-flight fetch (presumably cancelling it — gpui tasks are
// typically cancelled on drop; confirm), so results always reflect the
// latest configuration.
fn restart_fetch_models_task(&mut self, cx: &mut Context<Self>) {
self.fetch_models_task = Some(self.fetch_models(cx));
}
@ -70,15 +107,27 @@ impl OllamaService {
let api_url = self.api_url.clone();
cx.spawn(async move |this, cx| {
let models = match crate::get_models(http_client.as_ref(), &api_url, None).await {
// Get the current settings models to merge with API models
let settings_models = this.update(cx, |this, _cx| {
// Get just the names of models from settings to avoid duplicates
this.available_models
.iter()
.map(|m| m.name.clone())
.collect::<std::collections::HashSet<_>>()
})?;
// Fetch models from API
let api_models = match crate::get_models(http_client.as_ref(), &api_url, None).await {
Ok(models) => models,
Err(_) => return Ok(()), // Silently fail and use empty list
Err(_) => return Ok(()), // Silently fail if API is unavailable
};
let tasks = models
let tasks = api_models
.into_iter()
// Filter out embedding models
.filter(|model| !model.name.contains("-embed"))
// Filter out models that are already defined in settings
.filter(|model| !settings_models.contains(&model.name))
.map(|model| {
let http_client = Arc::clone(&http_client);
let api_url = api_url.clone();
@ -98,8 +147,8 @@ impl OllamaService {
}
});
// Rate-limit capability fetches
let mut ollama_models: Vec<_> = futures::stream::iter(tasks)
// Rate-limit capability fetches for API-discovered models
let api_discovered_models: Vec<_> = futures::stream::iter(tasks)
.buffer_unordered(5)
.collect::<Vec<Result<_>>>()
.await
@ -107,10 +156,11 @@ impl OllamaService {
.collect::<Result<Vec<_>>>()
.unwrap_or_default();
ollama_models.sort_by(|a, b| a.name.cmp(&b.name));
this.update(cx, |this, cx| {
this.available_models = ollama_models;
// Append API-discovered models to existing settings models
this.available_models.extend(api_discovered_models);
// Sort all models by name
this.available_models.sort_by(|a, b| a.name.cmp(&b.name));
cx.notify();
})?;
@ -397,6 +447,7 @@ mod tests {
use gpui::{AppContext, TestAppContext};
use client;
use language::Buffer;
use project::Project;
use settings::SettingsStore;
@ -406,6 +457,7 @@ mod tests {
let settings_store = SettingsStore::test(cx);
cx.set_global(settings_store);
theme::init(theme::LoadThemes::JustBase, cx);
client::init_settings(cx);
language::init(cx);
editor::init_settings(cx);
Project::init_settings(cx);
@ -930,4 +982,125 @@ mod tests {
assert_eq!(editor.text(cx), "");
});
}
// Verifies the settings/API merge behavior: models declared in settings must
// take precedence over same-named models reported by the Ollama API, unique
// models from both sources must all appear, and the merged list is sorted by
// name.
#[gpui::test]
async fn test_settings_model_merging(cx: &mut TestAppContext) {
init_test(cx);
// Create fake HTTP client that returns some API models
let fake_http_client = Arc::new(crate::fake::FakeHttpClient::new());
// Mock /api/tags response (list models). "shared-model" deliberately
// collides with a settings entry below to exercise precedence.
let models_response = serde_json::json!({
"models": [
{
"name": "api-model-1",
"modified_at": "2024-01-01T00:00:00Z",
"size": 1000000,
"digest": "abc123",
"details": {
"format": "gguf",
"family": "llama",
"families": ["llama"],
"parameter_size": "7B",
"quantization_level": "Q4_0"
}
},
{
"name": "shared-model",
"modified_at": "2024-01-01T00:00:00Z",
"size": 2000000,
"digest": "def456",
"details": {
"format": "gguf",
"family": "llama",
"families": ["llama"],
"parameter_size": "13B",
"quantization_level": "Q4_0"
}
}
]
});
fake_http_client.set_response("/api/tags", models_response.to_string());
// Mock /api/show responses for each model (capability probe)
let show_response = serde_json::json!({
"capabilities": ["tools", "vision"]
});
fake_http_client.set_response("/api/show", show_response.to_string());
// Create service pointed at the fake client
let service = cx.update(|cx| {
OllamaService::new(
fake_http_client.clone(),
"http://localhost:11434".to_string(),
cx,
)
});
// Add settings models (including one that overlaps with API)
let settings_models = vec![
SettingsModel {
name: "custom-model-1".to_string(),
display_name: Some("Custom Model 1".to_string()),
max_tokens: 4096,
supports_tools: Some(true),
supports_images: Some(false),
supports_thinking: Some(false),
},
SettingsModel {
name: "shared-model".to_string(), // This should take precedence over API
display_name: Some("Custom Shared Model".to_string()),
max_tokens: 8192,
supports_tools: Some(true),
supports_images: Some(true),
supports_thinking: Some(true),
},
];
cx.update(|cx| {
service.update(cx, |service, cx| {
service.set_settings_models(settings_models, cx);
});
});
// Wait for models to be fetched and merged
cx.run_until_parked();
// Verify merged models
let models = cx.update(|cx| service.read(cx).available_models().to_vec());
assert_eq!(models.len(), 3); // 2 settings models + 1 unique API model
// Models should be sorted alphabetically, so check by name
let model_names: Vec<&str> = models.iter().map(|m| m.name.as_str()).collect();
assert_eq!(
model_names,
vec!["api-model-1", "custom-model-1", "shared-model"]
);
// Check custom model from settings
let custom_model = models.iter().find(|m| m.name == "custom-model-1").unwrap();
assert_eq!(
custom_model.display_name,
Some("Custom Model 1".to_string())
);
assert_eq!(custom_model.max_tokens, 4096);
// Settings model should override API model for shared-model
let shared_model = models.iter().find(|m| m.name == "shared-model").unwrap();
assert_eq!(
shared_model.display_name,
Some("Custom Shared Model".to_string())
);
assert_eq!(shared_model.max_tokens, 8192);
assert_eq!(shared_model.supports_tools, Some(true));
// NOTE(review): the runtime model type exposes `supports_vision` while
// SettingsModel names the field `supports_images` — confirm the mapping
// in Model::new is intentional and the two refer to the same capability.
assert_eq!(shared_model.supports_vision, Some(true));
assert_eq!(shared_model.supports_thinking, Some(true));
// API-only model should be included
let api_model = models.iter().find(|m| m.name == "api-model-1").unwrap();
assert!(api_model.display_name.is_none()); // API models don't have custom display names
}
}

View file

@ -6,7 +6,7 @@ use gpui::{AnyWindowHandle, App, AppContext as _, Context, Entity, WeakEntity};
use language::language_settings::{EditPredictionProvider, all_language_settings};
use language_models::AllLanguageModelSettings;
use ollama::{OllamaCompletionProvider, OllamaService};
use ollama::{OllamaCompletionProvider, OllamaService, SettingsModel};
use settings::{Settings as _, SettingsStore};
use smol::stream::StreamExt;
use std::{cell::RefCell, rc::Rc, sync::Arc};
@ -19,8 +19,30 @@ use zeta::{ProviderDataCollection, ZetaInlineCompletionProvider};
pub fn init(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) {
// Initialize global Ollama service
let (api_url, settings_models) = {
let settings = &AllLanguageModelSettings::get_global(cx).ollama;
let ollama_service = OllamaService::new(client.http_client(), settings.api_url.clone(), cx);
let api_url = settings.api_url.clone();
let settings_models: Vec<SettingsModel> = settings
.available_models
.iter()
.map(|model| SettingsModel {
name: model.name.clone(),
display_name: model.display_name.clone(),
max_tokens: model.max_tokens,
supports_tools: model.supports_tools,
supports_images: model.supports_images,
supports_thinking: model.supports_thinking,
})
.collect();
(api_url, settings_models)
};
let ollama_service = OllamaService::new(client.http_client(), api_url, cx);
ollama_service.update(cx, |service, cx| {
service.set_settings_models(settings_models, cx);
});
OllamaService::set_global(ollama_service, cx);
let editors: Rc<RefCell<HashMap<WeakEntity<Editor>, AnyWindowHandle>>> = Rc::default();
@ -144,10 +166,23 @@ pub fn init(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) {
}
} else if provider == EditPredictionProvider::Ollama {
// Update global Ollama service when settings change
let _settings = &AllLanguageModelSettings::get_global(cx).ollama;
let settings = &AllLanguageModelSettings::get_global(cx).ollama;
if let Some(service) = OllamaService::global(cx) {
let settings_models: Vec<SettingsModel> = settings
.available_models
.iter()
.map(|model| SettingsModel {
name: model.name.clone(),
display_name: model.display_name.clone(),
max_tokens: model.max_tokens,
supports_tools: model.supports_tools,
supports_images: model.supports_images,
supports_thinking: model.supports_thinking,
})
.collect();
service.update(cx, |service, cx| {
service.refresh_models(cx);
service.set_settings_models(settings_models, cx);
});
}
}