Auto detect models WIP
This commit is contained in:
parent
5a1506c3c2
commit
0bdb42e65d
8 changed files with 952 additions and 128 deletions
|
@ -21,6 +21,8 @@ use language::{
|
|||
};
|
||||
use language_models::AllLanguageModelSettings;
|
||||
|
||||
use ollama;
|
||||
|
||||
use paths;
|
||||
use regex::Regex;
|
||||
use settings::{Settings, SettingsStore, update_settings_file};
|
||||
|
@ -413,6 +415,10 @@ impl InlineCompletionButton {
|
|||
cx.observe_global::<SettingsStore>(move |_, cx| cx.notify())
|
||||
.detach();
|
||||
|
||||
if let Some(service) = ollama::OllamaService::global(cx) {
|
||||
cx.observe(&service, |_, _, cx| cx.notify()).detach();
|
||||
}
|
||||
|
||||
Self {
|
||||
editor_subscription: None,
|
||||
editor_enabled: None,
|
||||
|
@ -858,8 +864,30 @@ impl InlineCompletionButton {
|
|||
let settings = AllLanguageModelSettings::get_global(cx);
|
||||
let ollama_settings = &settings.ollama;
|
||||
|
||||
// Clone needed values to avoid borrowing issues
|
||||
let available_models = ollama_settings.available_models.clone();
|
||||
// Get models from both settings and global service discovery
|
||||
let mut available_models = ollama_settings.available_models.clone();
|
||||
|
||||
// Add discovered models from the global Ollama service
|
||||
if let Some(service) = ollama::OllamaService::global(cx) {
|
||||
let discovered_models = service.read(cx).available_models();
|
||||
for model in discovered_models {
|
||||
// Convert from ollama::Model to language_models AvailableModel
|
||||
let available_model = language_models::provider::ollama::AvailableModel {
|
||||
name: model.name.clone(),
|
||||
display_name: model.display_name.clone(),
|
||||
max_tokens: model.max_tokens,
|
||||
keep_alive: model.keep_alive.clone(),
|
||||
supports_tools: model.supports_tools,
|
||||
supports_images: model.supports_vision,
|
||||
supports_thinking: model.supports_thinking,
|
||||
};
|
||||
|
||||
// Add if not already in settings (settings take precedence)
|
||||
if !available_models.iter().any(|m| m.name == model.name) {
|
||||
available_models.push(available_model);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// API URL configuration - only show if Ollama settings exist in the user's config
|
||||
let menu = if Self::ollama_settings_exist(cx) {
|
||||
|
@ -878,7 +906,7 @@ impl InlineCompletionButton {
|
|||
let menu = menu.separator().header("Available Models");
|
||||
|
||||
// Add each available model as a menu entry
|
||||
available_models.iter().fold(menu, |menu, model| {
|
||||
let menu = available_models.iter().fold(menu, |menu, model| {
|
||||
let model_name = model.display_name.as_ref().unwrap_or(&model.name);
|
||||
let is_current = available_models
|
||||
.first()
|
||||
|
@ -898,6 +926,13 @@ impl InlineCompletionButton {
|
|||
}
|
||||
},
|
||||
)
|
||||
});
|
||||
|
||||
// Add refresh models option
|
||||
menu.separator().entry("Refresh Models", None, {
|
||||
move |_window, cx| {
|
||||
Self::refresh_ollama_models(cx);
|
||||
}
|
||||
})
|
||||
} else {
|
||||
menu.separator()
|
||||
|
@ -908,6 +943,11 @@ impl InlineCompletionButton {
|
|||
Self::open_ollama_settings(fs.clone(), window, cx);
|
||||
}
|
||||
})
|
||||
.entry("Refresh Models", None, {
|
||||
move |_window, cx| {
|
||||
Self::refresh_ollama_models(cx);
|
||||
}
|
||||
})
|
||||
};
|
||||
|
||||
// Use the common language settings menu
|
||||
|
@ -997,6 +1037,14 @@ impl InlineCompletionButton {
|
|||
});
|
||||
}
|
||||
|
||||
fn refresh_ollama_models(cx: &mut App) {
|
||||
if let Some(service) = ollama::OllamaService::global(cx) {
|
||||
service.update(cx, |service, cx| {
|
||||
service.refresh_models(cx);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
pub fn update_enabled(&mut self, editor: Entity<Editor>, cx: &mut Context<Self>) {
|
||||
let editor = editor.read(cx);
|
||||
let snapshot = editor.buffer().read(cx).snapshot(cx);
|
||||
|
@ -1188,3 +1236,359 @@ fn toggle_edit_prediction_mode(fs: Arc<dyn Fs>, mode: EditPredictionsMode, cx: &
|
|||
});
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use clock::FakeSystemClock;
|
||||
use gpui::TestAppContext;
|
||||
use http_client;
|
||||
use language_models::provider::ollama::AvailableModel;
|
||||
use ollama::{OllamaService, fake::FakeHttpClient};
|
||||
use settings::SettingsStore;
|
||||
use std::sync::Arc;
|
||||
|
||||
fn init_test(cx: &mut TestAppContext) {
|
||||
cx.update(|cx| {
|
||||
let settings_store = SettingsStore::test(cx);
|
||||
cx.set_global(settings_store);
|
||||
gpui_tokio::init(cx);
|
||||
theme::init(theme::LoadThemes::JustBase, cx);
|
||||
language::init(cx);
|
||||
language_settings::init(cx);
|
||||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_ollama_menu_shows_discovered_models(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
|
||||
// Create fake HTTP client with mock models response
|
||||
let fake_http_client = Arc::new(FakeHttpClient::new());
|
||||
|
||||
// Mock /api/tags response
|
||||
let models_response = serde_json::json!({
|
||||
"models": [
|
||||
{
|
||||
"name": "qwen2.5-coder:3b",
|
||||
"modified_at": "2024-01-01T00:00:00Z",
|
||||
"size": 1000000,
|
||||
"digest": "abc123",
|
||||
"details": {
|
||||
"format": "gguf",
|
||||
"family": "qwen2",
|
||||
"families": ["qwen2"],
|
||||
"parameter_size": "3B",
|
||||
"quantization_level": "Q4_0"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "codellama:7b-code",
|
||||
"modified_at": "2024-01-01T00:00:00Z",
|
||||
"size": 2000000,
|
||||
"digest": "def456",
|
||||
"details": {
|
||||
"format": "gguf",
|
||||
"family": "codellama",
|
||||
"families": ["codellama"],
|
||||
"parameter_size": "7B",
|
||||
"quantization_level": "Q4_0"
|
||||
}
|
||||
}
|
||||
]
|
||||
});
|
||||
|
||||
fake_http_client.set_response("/api/tags", models_response.to_string());
|
||||
|
||||
// Mock /api/show response
|
||||
let capabilities = serde_json::json!({
|
||||
"capabilities": ["tools"]
|
||||
});
|
||||
fake_http_client.set_response("/api/show", capabilities.to_string());
|
||||
|
||||
// Create and set global Ollama service
|
||||
let service = cx.update(|cx| {
|
||||
OllamaService::new(
|
||||
fake_http_client.clone(),
|
||||
"http://localhost:11434".to_string(),
|
||||
cx,
|
||||
)
|
||||
});
|
||||
|
||||
cx.update(|cx| {
|
||||
OllamaService::set_global(service.clone(), cx);
|
||||
});
|
||||
|
||||
// Wait for model discovery
|
||||
cx.background_executor.run_until_parked();
|
||||
|
||||
// Verify models are accessible through the service
|
||||
cx.update(|cx| {
|
||||
if let Some(service) = OllamaService::global(cx) {
|
||||
let discovered_models = service.read(cx).available_models();
|
||||
assert_eq!(discovered_models.len(), 2);
|
||||
|
||||
let model_names: Vec<&str> =
|
||||
discovered_models.iter().map(|m| m.name.as_str()).collect();
|
||||
assert!(model_names.contains(&"qwen2.5-coder:3b"));
|
||||
assert!(model_names.contains(&"codellama:7b-code"));
|
||||
} else {
|
||||
panic!("Global service should be available");
|
||||
}
|
||||
});
|
||||
|
||||
// Verify the global service has the expected models
|
||||
service.read_with(cx, |service, _| {
|
||||
let models = service.available_models();
|
||||
assert_eq!(models.len(), 2);
|
||||
|
||||
let model_names: Vec<&str> = models.iter().map(|m| m.name.as_str()).collect();
|
||||
assert!(model_names.contains(&"qwen2.5-coder:3b"));
|
||||
assert!(model_names.contains(&"codellama:7b-code"));
|
||||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_ollama_menu_shows_service_models(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
|
||||
// Create fake HTTP client with models
|
||||
let fake_http_client = Arc::new(FakeHttpClient::new());
|
||||
|
||||
let models_response = serde_json::json!({
|
||||
"models": [
|
||||
{
|
||||
"name": "qwen2.5-coder:7b",
|
||||
"modified_at": "2024-01-01T00:00:00Z",
|
||||
"size": 1000000,
|
||||
"digest": "abc123",
|
||||
"details": {
|
||||
"format": "gguf",
|
||||
"family": "qwen2",
|
||||
"families": ["qwen2"],
|
||||
"parameter_size": "7B",
|
||||
"quantization_level": "Q4_0"
|
||||
}
|
||||
}
|
||||
]
|
||||
});
|
||||
|
||||
fake_http_client.set_response("/api/tags", models_response.to_string());
|
||||
fake_http_client.set_response(
|
||||
"/api/show",
|
||||
serde_json::json!({"capabilities": []}).to_string(),
|
||||
);
|
||||
|
||||
// Create and set global service
|
||||
let service = cx.update(|cx| {
|
||||
OllamaService::new(fake_http_client, "http://localhost:11434".to_string(), cx)
|
||||
});
|
||||
|
||||
cx.update(|cx| {
|
||||
OllamaService::set_global(service.clone(), cx);
|
||||
});
|
||||
|
||||
cx.background_executor.run_until_parked();
|
||||
|
||||
// Test that discovered models are accessible
|
||||
cx.update(|cx| {
|
||||
if let Some(service) = OllamaService::global(cx) {
|
||||
let discovered_models = service.read(cx).available_models();
|
||||
assert_eq!(discovered_models.len(), 1);
|
||||
assert_eq!(discovered_models[0].name, "qwen2.5-coder:7b");
|
||||
} else {
|
||||
panic!("Global service should be available");
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_ollama_menu_refreshes_on_service_update(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
|
||||
let fake_http_client = Arc::new(FakeHttpClient::new());
|
||||
|
||||
// Initially empty models
|
||||
fake_http_client.set_response("/api/tags", serde_json::json!({"models": []}).to_string());
|
||||
|
||||
let service = cx.update(|cx| {
|
||||
OllamaService::new(
|
||||
fake_http_client.clone(),
|
||||
"http://localhost:11434".to_string(),
|
||||
cx,
|
||||
)
|
||||
});
|
||||
|
||||
cx.update(|cx| {
|
||||
OllamaService::set_global(service.clone(), cx);
|
||||
});
|
||||
|
||||
cx.background_executor.run_until_parked();
|
||||
|
||||
// Verify the service subscription mechanism works by creating a button
|
||||
let _button = cx.update(|cx| {
|
||||
let fs = fs::FakeFs::new(cx.background_executor().clone());
|
||||
let user_store = cx.new(|cx| {
|
||||
client::UserStore::new(
|
||||
Arc::new(http_client::FakeHttpClient::create(|_| {
|
||||
Box::pin(async { Err(anyhow::anyhow!("not implemented")) })
|
||||
})),
|
||||
cx,
|
||||
)
|
||||
});
|
||||
let popover_handle = PopoverMenuHandle::default();
|
||||
|
||||
cx.new(|cx| InlineCompletionButton::new(fs, user_store, popover_handle, cx))
|
||||
});
|
||||
|
||||
// Verify initially no models
|
||||
service.read_with(cx, |service, _| {
|
||||
assert_eq!(service.available_models().len(), 0);
|
||||
});
|
||||
|
||||
// Update mock to return models
|
||||
let models_response = serde_json::json!({
|
||||
"models": [
|
||||
{
|
||||
"name": "phi3:mini",
|
||||
"modified_at": "2024-01-01T00:00:00Z",
|
||||
"size": 500000,
|
||||
"digest": "xyz789",
|
||||
"details": {
|
||||
"format": "gguf",
|
||||
"family": "phi3",
|
||||
"families": ["phi3"],
|
||||
"parameter_size": "3.8B",
|
||||
"quantization_level": "Q4_0"
|
||||
}
|
||||
}
|
||||
]
|
||||
});
|
||||
|
||||
fake_http_client.set_response("/api/tags", models_response.to_string());
|
||||
fake_http_client.set_response(
|
||||
"/api/show",
|
||||
serde_json::json!({"capabilities": []}).to_string(),
|
||||
);
|
||||
|
||||
// Trigger refresh
|
||||
service.update(cx, |service, cx| {
|
||||
service.refresh_models(cx);
|
||||
});
|
||||
|
||||
cx.background_executor.run_until_parked();
|
||||
|
||||
// Verify models were refreshed
|
||||
service.read_with(cx, |service, _| {
|
||||
let models = service.available_models();
|
||||
assert_eq!(models.len(), 1);
|
||||
assert_eq!(models[0].name, "phi3:mini");
|
||||
});
|
||||
|
||||
// The button should have been notified and will rebuild its menu with new models
|
||||
// when next requested (this tests the subscription mechanism)
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_refresh_models_button_functionality(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
|
||||
let fake_http_client = Arc::new(FakeHttpClient::new());
|
||||
|
||||
// Start with one model
|
||||
let initial_response = serde_json::json!({
|
||||
"models": [
|
||||
{
|
||||
"name": "mistral:7b",
|
||||
"modified_at": "2024-01-01T00:00:00Z",
|
||||
"size": 1000000,
|
||||
"digest": "initial123",
|
||||
"details": {
|
||||
"format": "gguf",
|
||||
"family": "mistral",
|
||||
"families": ["mistral"],
|
||||
"parameter_size": "7B",
|
||||
"quantization_level": "Q4_0"
|
||||
}
|
||||
}
|
||||
]
|
||||
});
|
||||
|
||||
fake_http_client.set_response("/api/tags", initial_response.to_string());
|
||||
fake_http_client.set_response(
|
||||
"/api/show",
|
||||
serde_json::json!({"capabilities": []}).to_string(),
|
||||
);
|
||||
|
||||
let service = cx.update(|cx| {
|
||||
OllamaService::new(
|
||||
fake_http_client.clone(),
|
||||
"http://localhost:11434".to_string(),
|
||||
cx,
|
||||
)
|
||||
});
|
||||
|
||||
cx.update(|cx| {
|
||||
OllamaService::set_global(service.clone(), cx);
|
||||
});
|
||||
|
||||
cx.background_executor.run_until_parked();
|
||||
|
||||
// Verify initial model
|
||||
service.read_with(cx, |service, _| {
|
||||
assert_eq!(service.available_models().len(), 1);
|
||||
assert_eq!(service.available_models()[0].name, "mistral:7b");
|
||||
});
|
||||
|
||||
// Update mock to simulate new model available
|
||||
let updated_response = serde_json::json!({
|
||||
"models": [
|
||||
{
|
||||
"name": "mistral:7b",
|
||||
"modified_at": "2024-01-01T00:00:00Z",
|
||||
"size": 1000000,
|
||||
"digest": "initial123",
|
||||
"details": {
|
||||
"format": "gguf",
|
||||
"family": "mistral",
|
||||
"families": ["mistral"],
|
||||
"parameter_size": "7B",
|
||||
"quantization_level": "Q4_0"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "gemma2:9b",
|
||||
"modified_at": "2024-01-01T00:00:00Z",
|
||||
"size": 2000000,
|
||||
"digest": "new456",
|
||||
"details": {
|
||||
"format": "gguf",
|
||||
"family": "gemma2",
|
||||
"families": ["gemma2"],
|
||||
"parameter_size": "9B",
|
||||
"quantization_level": "Q4_0"
|
||||
}
|
||||
}
|
||||
]
|
||||
});
|
||||
|
||||
fake_http_client.set_response("/api/tags", updated_response.to_string());
|
||||
|
||||
// Simulate clicking "Refresh Models" button
|
||||
cx.update(|cx| {
|
||||
InlineCompletionButton::refresh_ollama_models(cx);
|
||||
});
|
||||
|
||||
cx.background_executor.run_until_parked();
|
||||
|
||||
// Verify models were refreshed
|
||||
service.read_with(cx, |service, _| {
|
||||
let models = service.available_models();
|
||||
assert_eq!(models.len(), 2);
|
||||
|
||||
let model_names: Vec<&str> = models.iter().map(|m| m.name.as_str()).collect();
|
||||
assert!(model_names.contains(&"mistral:7b"));
|
||||
assert!(model_names.contains(&"gemma2:9b"));
|
||||
});
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue