Merge models from local settings with the ones listed by Ollama
This supports the scenario where the user doesn't have access to Ollama's model listing and needs to declare the models to Zed explicitly, by hand.
This commit is contained in:
parent
1060d1b301
commit
947781bc48
4 changed files with 224 additions and 13 deletions
1
Cargo.lock
generated
1
Cargo.lock
generated
|
@ -10945,6 +10945,7 @@ name = "ollama"
|
|||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"client",
|
||||
"editor",
|
||||
"futures 0.3.31",
|
||||
"gpui",
|
||||
|
|
|
@ -27,6 +27,7 @@ gpui.workspace = true
|
|||
http_client.workspace = true
|
||||
inline_completion.workspace = true
|
||||
language.workspace = true
|
||||
|
||||
log.workspace = true
|
||||
|
||||
project.workspace = true
|
||||
|
@ -38,6 +39,7 @@ text.workspace = true
|
|||
workspace-hack.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
client = { workspace = true, features = ["test-support"] }
|
||||
editor = { workspace = true, features = ["test-support"] }
|
||||
gpui = { workspace = true, features = ["test-support"] }
|
||||
http_client = { workspace = true, features = ["test-support"] }
|
||||
|
|
|
@ -13,6 +13,30 @@ use project::Project;
|
|||
|
||||
pub const OLLAMA_DEBOUNCE_TIMEOUT: Duration = Duration::from_millis(75);
|
||||
|
||||
/// Plain-data description of a model declared in settings.
///
/// This structure exists so settings model data can be passed in without
/// circular dependencies. It derives `PartialEq`/`Eq` so two descriptions can
/// be compared directly (for example in assertions).
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct SettingsModel {
    /// Name of the model.
    pub name: String,
    /// Optional display name for the model.
    pub display_name: Option<String>,
    /// Maximum token count declared for the model.
    pub max_tokens: u64,
    /// Whether the model supports tools; `None` when not specified.
    pub supports_tools: Option<bool>,
    /// Whether the model supports images; `None` when not specified.
    pub supports_images: Option<bool>,
    /// Whether the model supports thinking; `None` when not specified.
    pub supports_thinking: Option<bool>,
}
|
||||
|
||||
impl SettingsModel {
|
||||
pub fn to_model(&self) -> Model {
|
||||
Model::new(
|
||||
&self.name,
|
||||
self.display_name.as_deref(),
|
||||
Some(self.max_tokens),
|
||||
self.supports_tools,
|
||||
self.supports_images,
|
||||
self.supports_thinking,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// Global Ollama service for managing models across all providers
|
||||
pub struct OllamaService {
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
|
@ -61,6 +85,19 @@ impl OllamaService {
|
|||
self.restart_fetch_models_task(cx);
|
||||
}
|
||||
|
||||
pub fn set_settings_models(
|
||||
&mut self,
|
||||
settings_models: Vec<SettingsModel>,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
// Convert settings models to our Model type
|
||||
self.available_models = settings_models
|
||||
.into_iter()
|
||||
.map(|settings_model| settings_model.to_model())
|
||||
.collect();
|
||||
self.restart_fetch_models_task(cx);
|
||||
}
|
||||
|
||||
/// Starts a new model fetch and stores the value returned by `fetch_models`
/// in `fetch_models_task`, replacing whatever was stored there before.
fn restart_fetch_models_task(&mut self, cx: &mut Context<Self>) {
    let task = self.fetch_models(cx);
    self.fetch_models_task = Some(task);
}
|
||||
|
@ -70,15 +107,27 @@ impl OllamaService {
|
|||
let api_url = self.api_url.clone();
|
||||
|
||||
cx.spawn(async move |this, cx| {
|
||||
let models = match crate::get_models(http_client.as_ref(), &api_url, None).await {
|
||||
// Get the current settings models to merge with API models
|
||||
let settings_models = this.update(cx, |this, _cx| {
|
||||
// Get just the names of models from settings to avoid duplicates
|
||||
this.available_models
|
||||
.iter()
|
||||
.map(|m| m.name.clone())
|
||||
.collect::<std::collections::HashSet<_>>()
|
||||
})?;
|
||||
|
||||
// Fetch models from API
|
||||
let api_models = match crate::get_models(http_client.as_ref(), &api_url, None).await {
|
||||
Ok(models) => models,
|
||||
Err(_) => return Ok(()), // Silently fail and use empty list
|
||||
Err(_) => return Ok(()), // Silently fail if API is unavailable
|
||||
};
|
||||
|
||||
let tasks = models
|
||||
let tasks = api_models
|
||||
.into_iter()
|
||||
// Filter out embedding models
|
||||
.filter(|model| !model.name.contains("-embed"))
|
||||
// Filter out models that are already defined in settings
|
||||
.filter(|model| !settings_models.contains(&model.name))
|
||||
.map(|model| {
|
||||
let http_client = Arc::clone(&http_client);
|
||||
let api_url = api_url.clone();
|
||||
|
@ -98,8 +147,8 @@ impl OllamaService {
|
|||
}
|
||||
});
|
||||
|
||||
// Rate-limit capability fetches
|
||||
let mut ollama_models: Vec<_> = futures::stream::iter(tasks)
|
||||
// Rate-limit capability fetches for API-discovered models
|
||||
let api_discovered_models: Vec<_> = futures::stream::iter(tasks)
|
||||
.buffer_unordered(5)
|
||||
.collect::<Vec<Result<_>>>()
|
||||
.await
|
||||
|
@ -107,10 +156,11 @@ impl OllamaService {
|
|||
.collect::<Result<Vec<_>>>()
|
||||
.unwrap_or_default();
|
||||
|
||||
ollama_models.sort_by(|a, b| a.name.cmp(&b.name));
|
||||
|
||||
this.update(cx, |this, cx| {
|
||||
this.available_models = ollama_models;
|
||||
// Append API-discovered models to existing settings models
|
||||
this.available_models.extend(api_discovered_models);
|
||||
// Sort all models by name
|
||||
this.available_models.sort_by(|a, b| a.name.cmp(&b.name));
|
||||
cx.notify();
|
||||
})?;
|
||||
|
||||
|
@ -397,6 +447,7 @@ mod tests {
|
|||
|
||||
use gpui::{AppContext, TestAppContext};
|
||||
|
||||
use client;
|
||||
use language::Buffer;
|
||||
use project::Project;
|
||||
use settings::SettingsStore;
|
||||
|
@ -406,6 +457,7 @@ mod tests {
|
|||
let settings_store = SettingsStore::test(cx);
|
||||
cx.set_global(settings_store);
|
||||
theme::init(theme::LoadThemes::JustBase, cx);
|
||||
client::init_settings(cx);
|
||||
language::init(cx);
|
||||
editor::init_settings(cx);
|
||||
Project::init_settings(cx);
|
||||
|
@ -930,4 +982,125 @@ mod tests {
|
|||
assert_eq!(editor.text(cx), "");
|
||||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_settings_model_merging(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
|
||||
// Create fake HTTP client that returns some API models
|
||||
let fake_http_client = Arc::new(crate::fake::FakeHttpClient::new());
|
||||
|
||||
// Mock /api/tags response (list models)
|
||||
let models_response = serde_json::json!({
|
||||
"models": [
|
||||
{
|
||||
"name": "api-model-1",
|
||||
"modified_at": "2024-01-01T00:00:00Z",
|
||||
"size": 1000000,
|
||||
"digest": "abc123",
|
||||
"details": {
|
||||
"format": "gguf",
|
||||
"family": "llama",
|
||||
"families": ["llama"],
|
||||
"parameter_size": "7B",
|
||||
"quantization_level": "Q4_0"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "shared-model",
|
||||
"modified_at": "2024-01-01T00:00:00Z",
|
||||
"size": 2000000,
|
||||
"digest": "def456",
|
||||
"details": {
|
||||
"format": "gguf",
|
||||
"family": "llama",
|
||||
"families": ["llama"],
|
||||
"parameter_size": "13B",
|
||||
"quantization_level": "Q4_0"
|
||||
}
|
||||
}
|
||||
]
|
||||
});
|
||||
|
||||
fake_http_client.set_response("/api/tags", models_response.to_string());
|
||||
|
||||
// Mock /api/show responses for each model
|
||||
let show_response = serde_json::json!({
|
||||
"capabilities": ["tools", "vision"]
|
||||
});
|
||||
fake_http_client.set_response("/api/show", show_response.to_string());
|
||||
|
||||
// Create service
|
||||
let service = cx.update(|cx| {
|
||||
OllamaService::new(
|
||||
fake_http_client.clone(),
|
||||
"http://localhost:11434".to_string(),
|
||||
cx,
|
||||
)
|
||||
});
|
||||
|
||||
// Add settings models (including one that overlaps with API)
|
||||
let settings_models = vec![
|
||||
SettingsModel {
|
||||
name: "custom-model-1".to_string(),
|
||||
display_name: Some("Custom Model 1".to_string()),
|
||||
max_tokens: 4096,
|
||||
supports_tools: Some(true),
|
||||
supports_images: Some(false),
|
||||
supports_thinking: Some(false),
|
||||
},
|
||||
SettingsModel {
|
||||
name: "shared-model".to_string(), // This should take precedence over API
|
||||
display_name: Some("Custom Shared Model".to_string()),
|
||||
max_tokens: 8192,
|
||||
supports_tools: Some(true),
|
||||
supports_images: Some(true),
|
||||
supports_thinking: Some(true),
|
||||
},
|
||||
];
|
||||
|
||||
cx.update(|cx| {
|
||||
service.update(cx, |service, cx| {
|
||||
service.set_settings_models(settings_models, cx);
|
||||
});
|
||||
});
|
||||
|
||||
// Wait for models to be fetched and merged
|
||||
cx.run_until_parked();
|
||||
|
||||
// Verify merged models
|
||||
let models = cx.update(|cx| service.read(cx).available_models().to_vec());
|
||||
|
||||
assert_eq!(models.len(), 3); // 2 settings models + 1 unique API model
|
||||
|
||||
// Models should be sorted alphabetically, so check by name
|
||||
let model_names: Vec<&str> = models.iter().map(|m| m.name.as_str()).collect();
|
||||
assert_eq!(
|
||||
model_names,
|
||||
vec!["api-model-1", "custom-model-1", "shared-model"]
|
||||
);
|
||||
|
||||
// Check custom model from settings
|
||||
let custom_model = models.iter().find(|m| m.name == "custom-model-1").unwrap();
|
||||
assert_eq!(
|
||||
custom_model.display_name,
|
||||
Some("Custom Model 1".to_string())
|
||||
);
|
||||
assert_eq!(custom_model.max_tokens, 4096);
|
||||
|
||||
// Settings model should override API model for shared-model
|
||||
let shared_model = models.iter().find(|m| m.name == "shared-model").unwrap();
|
||||
assert_eq!(
|
||||
shared_model.display_name,
|
||||
Some("Custom Shared Model".to_string())
|
||||
);
|
||||
assert_eq!(shared_model.max_tokens, 8192);
|
||||
assert_eq!(shared_model.supports_tools, Some(true));
|
||||
assert_eq!(shared_model.supports_vision, Some(true));
|
||||
assert_eq!(shared_model.supports_thinking, Some(true));
|
||||
|
||||
// API-only model should be included
|
||||
let api_model = models.iter().find(|m| m.name == "api-model-1").unwrap();
|
||||
assert!(api_model.display_name.is_none()); // API models don't have custom display names
|
||||
}
|
||||
}
|
||||
|
|
|
@ -6,7 +6,7 @@ use gpui::{AnyWindowHandle, App, AppContext as _, Context, Entity, WeakEntity};
|
|||
|
||||
use language::language_settings::{EditPredictionProvider, all_language_settings};
|
||||
use language_models::AllLanguageModelSettings;
|
||||
use ollama::{OllamaCompletionProvider, OllamaService};
|
||||
use ollama::{OllamaCompletionProvider, OllamaService, SettingsModel};
|
||||
use settings::{Settings as _, SettingsStore};
|
||||
use smol::stream::StreamExt;
|
||||
use std::{cell::RefCell, rc::Rc, sync::Arc};
|
||||
|
@ -19,8 +19,30 @@ use zeta::{ProviderDataCollection, ZetaInlineCompletionProvider};
|
|||
|
||||
pub fn init(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) {
|
||||
// Initialize global Ollama service
|
||||
let settings = &AllLanguageModelSettings::get_global(cx).ollama;
|
||||
let ollama_service = OllamaService::new(client.http_client(), settings.api_url.clone(), cx);
|
||||
let (api_url, settings_models) = {
|
||||
let settings = &AllLanguageModelSettings::get_global(cx).ollama;
|
||||
let api_url = settings.api_url.clone();
|
||||
let settings_models: Vec<SettingsModel> = settings
|
||||
.available_models
|
||||
.iter()
|
||||
.map(|model| SettingsModel {
|
||||
name: model.name.clone(),
|
||||
display_name: model.display_name.clone(),
|
||||
max_tokens: model.max_tokens,
|
||||
supports_tools: model.supports_tools,
|
||||
supports_images: model.supports_images,
|
||||
supports_thinking: model.supports_thinking,
|
||||
})
|
||||
.collect();
|
||||
(api_url, settings_models)
|
||||
};
|
||||
|
||||
let ollama_service = OllamaService::new(client.http_client(), api_url, cx);
|
||||
|
||||
ollama_service.update(cx, |service, cx| {
|
||||
service.set_settings_models(settings_models, cx);
|
||||
});
|
||||
|
||||
OllamaService::set_global(ollama_service, cx);
|
||||
|
||||
let editors: Rc<RefCell<HashMap<WeakEntity<Editor>, AnyWindowHandle>>> = Rc::default();
|
||||
|
@ -144,10 +166,23 @@ pub fn init(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) {
|
|||
}
|
||||
} else if provider == EditPredictionProvider::Ollama {
|
||||
// Update global Ollama service when settings change
|
||||
let _settings = &AllLanguageModelSettings::get_global(cx).ollama;
|
||||
let settings = &AllLanguageModelSettings::get_global(cx).ollama;
|
||||
if let Some(service) = OllamaService::global(cx) {
|
||||
let settings_models: Vec<SettingsModel> = settings
|
||||
.available_models
|
||||
.iter()
|
||||
.map(|model| SettingsModel {
|
||||
name: model.name.clone(),
|
||||
display_name: model.display_name.clone(),
|
||||
max_tokens: model.max_tokens,
|
||||
supports_tools: model.supports_tools,
|
||||
supports_images: model.supports_images,
|
||||
supports_thinking: model.supports_thinking,
|
||||
})
|
||||
.collect();
|
||||
|
||||
service.update(cx, |service, cx| {
|
||||
service.refresh_models(cx);
|
||||
service.set_settings_models(settings_models, cx);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue