Auto detect models WIP

This commit is contained in:
Oliver Azevedo Barnes 2025-07-25 10:21:32 +01:00
parent 5a1506c3c2
commit 0bdb42e65d
No known key found for this signature in database
8 changed files with 952 additions and 128 deletions

View file

@ -21,6 +21,8 @@ use language::{
};
use language_models::AllLanguageModelSettings;
use ollama;
use paths;
use regex::Regex;
use settings::{Settings, SettingsStore, update_settings_file};
@ -413,6 +415,10 @@ impl InlineCompletionButton {
cx.observe_global::<SettingsStore>(move |_, cx| cx.notify())
.detach();
if let Some(service) = ollama::OllamaService::global(cx) {
cx.observe(&service, |_, _, cx| cx.notify()).detach();
}
Self {
editor_subscription: None,
editor_enabled: None,
@ -858,8 +864,30 @@ impl InlineCompletionButton {
let settings = AllLanguageModelSettings::get_global(cx);
let ollama_settings = &settings.ollama;
// Clone needed values to avoid borrowing issues
let available_models = ollama_settings.available_models.clone();
// Get models from both settings and global service discovery
let mut available_models = ollama_settings.available_models.clone();
// Add discovered models from the global Ollama service
if let Some(service) = ollama::OllamaService::global(cx) {
let discovered_models = service.read(cx).available_models();
for model in discovered_models {
// Convert from ollama::Model to language_models AvailableModel
let available_model = language_models::provider::ollama::AvailableModel {
name: model.name.clone(),
display_name: model.display_name.clone(),
max_tokens: model.max_tokens,
keep_alive: model.keep_alive.clone(),
supports_tools: model.supports_tools,
supports_images: model.supports_vision,
supports_thinking: model.supports_thinking,
};
// Add if not already in settings (settings take precedence)
if !available_models.iter().any(|m| m.name == model.name) {
available_models.push(available_model);
}
}
}
// API URL configuration - only show if Ollama settings exist in the user's config
let menu = if Self::ollama_settings_exist(cx) {
@ -878,7 +906,7 @@ impl InlineCompletionButton {
let menu = menu.separator().header("Available Models");
// Add each available model as a menu entry
available_models.iter().fold(menu, |menu, model| {
let menu = available_models.iter().fold(menu, |menu, model| {
let model_name = model.display_name.as_ref().unwrap_or(&model.name);
let is_current = available_models
.first()
@ -898,6 +926,13 @@ impl InlineCompletionButton {
}
},
)
});
// Add refresh models option
menu.separator().entry("Refresh Models", None, {
move |_window, cx| {
Self::refresh_ollama_models(cx);
}
})
} else {
menu.separator()
@ -908,6 +943,11 @@ impl InlineCompletionButton {
Self::open_ollama_settings(fs.clone(), window, cx);
}
})
.entry("Refresh Models", None, {
move |_window, cx| {
Self::refresh_ollama_models(cx);
}
})
};
// Use the common language settings menu
@ -997,6 +1037,14 @@ impl InlineCompletionButton {
});
}
fn refresh_ollama_models(cx: &mut App) {
if let Some(service) = ollama::OllamaService::global(cx) {
service.update(cx, |service, cx| {
service.refresh_models(cx);
});
}
}
pub fn update_enabled(&mut self, editor: Entity<Editor>, cx: &mut Context<Self>) {
let editor = editor.read(cx);
let snapshot = editor.buffer().read(cx).snapshot(cx);
@ -1188,3 +1236,359 @@ fn toggle_edit_prediction_mode(fs: Arc<dyn Fs>, mode: EditPredictionsMode, cx: &
});
}
}
#[cfg(test)]
mod tests {
use super::*;
use clock::FakeSystemClock;
use gpui::TestAppContext;
use http_client;
use language_models::provider::ollama::AvailableModel;
use ollama::{OllamaService, fake::FakeHttpClient};
use settings::SettingsStore;
use std::sync::Arc;
fn init_test(cx: &mut TestAppContext) {
cx.update(|cx| {
let settings_store = SettingsStore::test(cx);
cx.set_global(settings_store);
gpui_tokio::init(cx);
theme::init(theme::LoadThemes::JustBase, cx);
language::init(cx);
language_settings::init(cx);
});
}
#[gpui::test]
async fn test_ollama_menu_shows_discovered_models(cx: &mut TestAppContext) {
init_test(cx);
// Create fake HTTP client with mock models response
let fake_http_client = Arc::new(FakeHttpClient::new());
// Mock /api/tags response
let models_response = serde_json::json!({
"models": [
{
"name": "qwen2.5-coder:3b",
"modified_at": "2024-01-01T00:00:00Z",
"size": 1000000,
"digest": "abc123",
"details": {
"format": "gguf",
"family": "qwen2",
"families": ["qwen2"],
"parameter_size": "3B",
"quantization_level": "Q4_0"
}
},
{
"name": "codellama:7b-code",
"modified_at": "2024-01-01T00:00:00Z",
"size": 2000000,
"digest": "def456",
"details": {
"format": "gguf",
"family": "codellama",
"families": ["codellama"],
"parameter_size": "7B",
"quantization_level": "Q4_0"
}
}
]
});
fake_http_client.set_response("/api/tags", models_response.to_string());
// Mock /api/show response
let capabilities = serde_json::json!({
"capabilities": ["tools"]
});
fake_http_client.set_response("/api/show", capabilities.to_string());
// Create and set global Ollama service
let service = cx.update(|cx| {
OllamaService::new(
fake_http_client.clone(),
"http://localhost:11434".to_string(),
cx,
)
});
cx.update(|cx| {
OllamaService::set_global(service.clone(), cx);
});
// Wait for model discovery
cx.background_executor.run_until_parked();
// Verify models are accessible through the service
cx.update(|cx| {
if let Some(service) = OllamaService::global(cx) {
let discovered_models = service.read(cx).available_models();
assert_eq!(discovered_models.len(), 2);
let model_names: Vec<&str> =
discovered_models.iter().map(|m| m.name.as_str()).collect();
assert!(model_names.contains(&"qwen2.5-coder:3b"));
assert!(model_names.contains(&"codellama:7b-code"));
} else {
panic!("Global service should be available");
}
});
// Verify the global service has the expected models
service.read_with(cx, |service, _| {
let models = service.available_models();
assert_eq!(models.len(), 2);
let model_names: Vec<&str> = models.iter().map(|m| m.name.as_str()).collect();
assert!(model_names.contains(&"qwen2.5-coder:3b"));
assert!(model_names.contains(&"codellama:7b-code"));
});
}
#[gpui::test]
async fn test_ollama_menu_shows_service_models(cx: &mut TestAppContext) {
init_test(cx);
// Create fake HTTP client with models
let fake_http_client = Arc::new(FakeHttpClient::new());
let models_response = serde_json::json!({
"models": [
{
"name": "qwen2.5-coder:7b",
"modified_at": "2024-01-01T00:00:00Z",
"size": 1000000,
"digest": "abc123",
"details": {
"format": "gguf",
"family": "qwen2",
"families": ["qwen2"],
"parameter_size": "7B",
"quantization_level": "Q4_0"
}
}
]
});
fake_http_client.set_response("/api/tags", models_response.to_string());
fake_http_client.set_response(
"/api/show",
serde_json::json!({"capabilities": []}).to_string(),
);
// Create and set global service
let service = cx.update(|cx| {
OllamaService::new(fake_http_client, "http://localhost:11434".to_string(), cx)
});
cx.update(|cx| {
OllamaService::set_global(service.clone(), cx);
});
cx.background_executor.run_until_parked();
// Test that discovered models are accessible
cx.update(|cx| {
if let Some(service) = OllamaService::global(cx) {
let discovered_models = service.read(cx).available_models();
assert_eq!(discovered_models.len(), 1);
assert_eq!(discovered_models[0].name, "qwen2.5-coder:7b");
} else {
panic!("Global service should be available");
}
});
}
#[gpui::test]
async fn test_ollama_menu_refreshes_on_service_update(cx: &mut TestAppContext) {
init_test(cx);
let fake_http_client = Arc::new(FakeHttpClient::new());
// Initially empty models
fake_http_client.set_response("/api/tags", serde_json::json!({"models": []}).to_string());
let service = cx.update(|cx| {
OllamaService::new(
fake_http_client.clone(),
"http://localhost:11434".to_string(),
cx,
)
});
cx.update(|cx| {
OllamaService::set_global(service.clone(), cx);
});
cx.background_executor.run_until_parked();
// Verify the service subscription mechanism works by creating a button
let _button = cx.update(|cx| {
let fs = fs::FakeFs::new(cx.background_executor().clone());
let user_store = cx.new(|cx| {
client::UserStore::new(
Arc::new(http_client::FakeHttpClient::create(|_| {
Box::pin(async { Err(anyhow::anyhow!("not implemented")) })
})),
cx,
)
});
let popover_handle = PopoverMenuHandle::default();
cx.new(|cx| InlineCompletionButton::new(fs, user_store, popover_handle, cx))
});
// Verify initially no models
service.read_with(cx, |service, _| {
assert_eq!(service.available_models().len(), 0);
});
// Update mock to return models
let models_response = serde_json::json!({
"models": [
{
"name": "phi3:mini",
"modified_at": "2024-01-01T00:00:00Z",
"size": 500000,
"digest": "xyz789",
"details": {
"format": "gguf",
"family": "phi3",
"families": ["phi3"],
"parameter_size": "3.8B",
"quantization_level": "Q4_0"
}
}
]
});
fake_http_client.set_response("/api/tags", models_response.to_string());
fake_http_client.set_response(
"/api/show",
serde_json::json!({"capabilities": []}).to_string(),
);
// Trigger refresh
service.update(cx, |service, cx| {
service.refresh_models(cx);
});
cx.background_executor.run_until_parked();
// Verify models were refreshed
service.read_with(cx, |service, _| {
let models = service.available_models();
assert_eq!(models.len(), 1);
assert_eq!(models[0].name, "phi3:mini");
});
// The button should have been notified and will rebuild its menu with new models
// when next requested (this tests the subscription mechanism)
}
#[gpui::test]
async fn test_refresh_models_button_functionality(cx: &mut TestAppContext) {
init_test(cx);
let fake_http_client = Arc::new(FakeHttpClient::new());
// Start with one model
let initial_response = serde_json::json!({
"models": [
{
"name": "mistral:7b",
"modified_at": "2024-01-01T00:00:00Z",
"size": 1000000,
"digest": "initial123",
"details": {
"format": "gguf",
"family": "mistral",
"families": ["mistral"],
"parameter_size": "7B",
"quantization_level": "Q4_0"
}
}
]
});
fake_http_client.set_response("/api/tags", initial_response.to_string());
fake_http_client.set_response(
"/api/show",
serde_json::json!({"capabilities": []}).to_string(),
);
let service = cx.update(|cx| {
OllamaService::new(
fake_http_client.clone(),
"http://localhost:11434".to_string(),
cx,
)
});
cx.update(|cx| {
OllamaService::set_global(service.clone(), cx);
});
cx.background_executor.run_until_parked();
// Verify initial model
service.read_with(cx, |service, _| {
assert_eq!(service.available_models().len(), 1);
assert_eq!(service.available_models()[0].name, "mistral:7b");
});
// Update mock to simulate new model available
let updated_response = serde_json::json!({
"models": [
{
"name": "mistral:7b",
"modified_at": "2024-01-01T00:00:00Z",
"size": 1000000,
"digest": "initial123",
"details": {
"format": "gguf",
"family": "mistral",
"families": ["mistral"],
"parameter_size": "7B",
"quantization_level": "Q4_0"
}
},
{
"name": "gemma2:9b",
"modified_at": "2024-01-01T00:00:00Z",
"size": 2000000,
"digest": "new456",
"details": {
"format": "gguf",
"family": "gemma2",
"families": ["gemma2"],
"parameter_size": "9B",
"quantization_level": "Q4_0"
}
}
]
});
fake_http_client.set_response("/api/tags", updated_response.to_string());
// Simulate clicking "Refresh Models" button
cx.update(|cx| {
InlineCompletionButton::refresh_ollama_models(cx);
});
cx.background_executor.run_until_parked();
// Verify models were refreshed
service.read_with(cx, |service, _| {
let models = service.available_models();
assert_eq!(models.len(), 2);
let model_names: Vec<&str> = models.iter().map(|m| m.name.as_str()).collect();
assert!(model_names.contains(&"mistral:7b"));
assert!(model_names.contains(&"gemma2:9b"));
});
}
}