Auto detect models WIP

Oliver Azevedo Barnes, 2025-07-25 10:21:32 +01:00
parent 5a1506c3c2
commit 0bdb42e65d
8 changed files with 952 additions and 128 deletions

@@ -1,43 +1,172 @@
use crate::{GenerateOptions, GenerateRequest, generate};
use crate::{GenerateOptions, GenerateRequest, Model, generate};
use anyhow::{Context as AnyhowContext, Result};
use futures::StreamExt;
use std::{path::Path, sync::Arc, time::Duration};
use gpui::{App, Context, Entity, EntityId, Task};
use gpui::{App, AppContext, Context, Entity, EntityId, Global, Subscription, Task};
use http_client::HttpClient;
use inline_completion::{Direction, EditPredictionProvider, InlineCompletion};
use language::{Anchor, Buffer, ToOffset};
use settings::SettingsStore;
use project::Project;
use std::{path::Path, sync::Arc, time::Duration};
pub const OLLAMA_DEBOUNCE_TIMEOUT: Duration = Duration::from_millis(75);
pub struct OllamaCompletionProvider {
// Global Ollama service for managing models across all providers
pub struct OllamaService {
http_client: Arc<dyn HttpClient>,
api_url: String,
available_models: Vec<Model>,
fetch_models_task: Option<Task<Result<()>>>,
_settings_subscription: Subscription,
}
impl OllamaService {
pub fn new(http_client: Arc<dyn HttpClient>, api_url: String, cx: &mut App) -> Entity<Self> {
cx.new(|cx| {
let subscription = cx.observe_global::<SettingsStore>({
move |this: &mut OllamaService, cx| {
this.restart_fetch_models_task(cx);
}
});
let mut service = Self {
http_client,
api_url,
available_models: Vec::new(),
fetch_models_task: None,
_settings_subscription: subscription,
};
service.restart_fetch_models_task(cx);
service
})
}
pub fn global(cx: &App) -> Option<Entity<Self>> {
cx.try_global::<GlobalOllamaService>()
.map(|service| service.0.clone())
}
pub fn set_global(service: Entity<Self>, cx: &mut App) {
cx.set_global(GlobalOllamaService(service));
}
pub fn available_models(&self) -> &[Model] {
&self.available_models
}
pub fn refresh_models(&mut self, cx: &mut Context<Self>) {
self.restart_fetch_models_task(cx);
}
fn restart_fetch_models_task(&mut self, cx: &mut Context<Self>) {
self.fetch_models_task = Some(self.fetch_models(cx));
}
fn fetch_models(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
let http_client = Arc::clone(&self.http_client);
let api_url = self.api_url.clone();
cx.spawn(async move |this, cx| {
let models = match crate::get_models(http_client.as_ref(), &api_url, None).await {
Ok(models) => models,
Err(_) => return Ok(()), // Silently bail out, keeping the current (initially empty) model list
};
let tasks = models
.into_iter()
// Filter out embedding models (name-based heuristic, e.g. "nomic-embed-text")
.filter(|model| !model.name.contains("-embed"))
.map(|model| {
let http_client = Arc::clone(&http_client);
let api_url = api_url.clone();
async move {
let name = model.name.as_str();
let capabilities =
crate::show_model(http_client.as_ref(), &api_url, name).await?;
let ollama_model = Model::new(
name,
None,
None,
Some(capabilities.supports_tools()),
Some(capabilities.supports_vision()),
Some(capabilities.supports_thinking()),
);
Ok(ollama_model)
}
});
// Rate-limit capability fetches to five concurrent /api/show requests
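// Note: collecting into a single Result below means one failed capability
// lookup makes unwrap_or_default discard the entire list.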
let mut ollama_models: Vec<_> = futures::stream::iter(tasks)
.buffer_unordered(5)
.collect::<Vec<Result<_>>>()
.await
.into_iter()
.collect::<Result<Vec<_>>>()
.unwrap_or_default();
ollama_models.sort_by(|a, b| a.name.cmp(&b.name));
this.update(cx, |this, cx| {
this.available_models = ollama_models;
cx.notify();
})?;
Ok(())
})
}
}
struct GlobalOllamaService(Entity<OllamaService>);
impl Global for GlobalOllamaService {}
pub struct OllamaCompletionProvider {
model: String,
buffer_id: Option<EntityId>,
file_extension: Option<String>,
current_completion: Option<String>,
pending_refresh: Option<Task<Result<()>>>,
api_key: Option<String>,
_service_subscription: Option<Subscription>,
}
impl OllamaCompletionProvider {
pub fn new(
http_client: Arc<dyn HttpClient>,
api_url: String,
model: String,
api_key: Option<String>,
) -> Self {
pub fn new(model: String, api_key: Option<String>, cx: &mut Context<Self>) -> Self {
let subscription = if let Some(service) = OllamaService::global(cx) {
Some(cx.observe(&service, |_this, _service, cx| {
cx.notify();
}))
} else {
None
};
Self {
http_client,
api_url,
model,
buffer_id: None,
file_extension: None,
current_completion: None,
pending_refresh: None,
api_key,
_service_subscription: subscription,
}
}
pub fn available_models(&self, cx: &App) -> Vec<Model> {
if let Some(service) = OllamaService::global(cx) {
service.read(cx).available_models().to_vec()
} else {
Vec::new()
}
}
pub fn refresh_models(&self, cx: &mut App) {
if let Some(service) = OllamaService::global(cx) {
service.update(cx, |service, cx| {
service.refresh_models(cx);
});
}
}
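For context, here is a minimal sketch of how a host application might wire the new global service up at startup and consume it. It assumes the gpui global-entity pattern used above; the http_client handle and the startup cx come from the host, not from this diff.

// Hypothetical startup wiring; only the OllamaService API comes from this commit.
let service = OllamaService::new(
    http_client.clone(),                  // any Arc<dyn HttpClient>
    "http://localhost:11434".to_string(), // default Ollama endpoint
    cx,
);
OllamaService::set_global(service, cx);

// Any consumer can then read or refresh the shared model list:
if let Some(service) = OllamaService::global(cx) {
    let models: Vec<Model> = service.read(cx).available_models().to_vec();
    service.update(cx, |service, cx| service.refresh_models(cx));
}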
@@ -104,14 +233,28 @@ impl EditPredictionProvider for OllamaCompletionProvider
fn refresh(
&mut self,
_project: Option<Entity<Project>>,
project: Option<Entity<Project>>,
buffer: Entity<Buffer>,
cursor_position: Anchor,
debounce: bool,
cx: &mut Context<Self>,
) {
let http_client = self.http_client.clone();
let api_url = self.api_url.clone();
// Get API settings from the global Ollama service or fallback
let (http_client, api_url) = if let Some(service) = OllamaService::global(cx) {
let service_ref = service.read(cx);
(service_ref.http_client.clone(), service_ref.api_url.clone())
} else {
// Fallback if global service isn't available
(
project
.as_ref()
.map(|p| p.read(cx).client().http_client() as Arc<dyn HttpClient>)
.unwrap_or_else(|| {
Arc::new(http_client::BlockedHttpClient::new()) as Arc<dyn HttpClient>
}),
crate::OLLAMA_API_URL.to_string(),
)
};
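// Note: BlockedHttpClient is expected to reject every request, so with no
// global service and no project the generate call below fails and simply
// clears any pending completion.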
self.pending_refresh = Some(cx.spawn(async move |this, cx| {
if debounce {
@@ -156,14 +299,17 @@ impl EditPredictionProvider for OllamaCompletionProvider
let response = generate(http_client.as_ref(), &api_url, api_key, request)
.await
.context("Failed to get completion from Ollama")?;
.context("Failed to get completion from Ollama");
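// The Result is matched below rather than propagated with ?, so a failed
// request clears the pending completion instead of aborting the task.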
this.update(cx, |this, cx| {
this.pending_refresh = None;
if !response.response.trim().is_empty() {
this.current_completion = Some(response.response);
} else {
this.current_completion = None;
match response {
Ok(response) if !response.response.trim().is_empty() => {
this.current_completion = Some(response.response);
}
_ => {
this.current_completion = None;
}
}
cx.notify();
})?;
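A related sketch: how a consumer could re-render when the shared model list changes, mirroring the cx.observe subscription the provider itself registers in new(). The dropdown update is an illustrative assumption, not part of this diff.

// Hypothetical observer inside some entity's Context: re-runs when the service notifies.
let _subscription = cx.observe(&service, |_this, service, cx| {
    let names: Vec<String> = service
        .read(cx)
        .available_models()
        .iter()
        .map(|model| model.name.clone())
        .collect();
    // ...repopulate a model-picker dropdown from names...
    cx.notify();
});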
@@ -248,7 +394,6 @@ impl EditPredictionProvider for OllamaCompletionProvider
#[cfg(test)]
mod tests {
use super::*;
use crate::fake::Ollama;
use gpui::{AppContext, TestAppContext};
@@ -269,31 +414,238 @@ mod tests {
}
/// Test the complete Ollama completion flow from refresh to suggestion
#[test]
fn test_get_stop_tokens() {
let http_client = Arc::new(crate::fake::FakeHttpClient::new());
#[gpui::test]
fn test_get_stop_tokens(cx: &mut TestAppContext) {
init_test(cx);
// Test CodeLlama code model gets stop tokens
let codellama_provider = OllamaCompletionProvider::new(
http_client.clone(),
"http://localhost:11434".to_string(),
"codellama:7b-code".to_string(),
None,
);
let codellama_provider = cx.update(|cx| {
cx.new(|cx| OllamaCompletionProvider::new("codellama:7b-code".to_string(), None, cx))
});
assert_eq!(
codellama_provider.get_stop_tokens(),
Some(vec!["<EOT>".to_string()])
);
codellama_provider.read_with(cx, |provider, _| {
assert_eq!(provider.get_stop_tokens(), Some(vec!["<EOT>".to_string()]));
});
// Test non-CodeLlama model doesn't get stop tokens
let qwen_provider = OllamaCompletionProvider::new(
http_client.clone(),
"http://localhost:11434".to_string(),
"qwen2.5-coder:3b".to_string(),
None,
);
assert_eq!(qwen_provider.get_stop_tokens(), None);
let qwen_provider = cx.update(|cx| {
cx.new(|cx| OllamaCompletionProvider::new("qwen2.5-coder:3b".to_string(), None, cx))
});
qwen_provider.read_with(cx, |provider, _| {
assert_eq!(provider.get_stop_tokens(), None);
});
}
#[gpui::test]
async fn test_model_discovery(cx: &mut TestAppContext) {
init_test(cx);
// Create fake HTTP client
let fake_http_client = Arc::new(crate::fake::FakeHttpClient::new());
// Mock /api/tags response (list models)
let models_response = serde_json::json!({
"models": [
{
"name": "qwen2.5-coder:3b",
"modified_at": "2024-01-01T00:00:00Z",
"size": 1000000,
"digest": "abc123",
"details": {
"format": "gguf",
"family": "qwen2",
"families": ["qwen2"],
"parameter_size": "3B",
"quantization_level": "Q4_0"
}
},
{
"name": "codellama:7b-code",
"modified_at": "2024-01-01T00:00:00Z",
"size": 2000000,
"digest": "def456",
"details": {
"format": "gguf",
"family": "codellama",
"families": ["codellama"],
"parameter_size": "7B",
"quantization_level": "Q4_0"
}
},
{
"name": "nomic-embed-text",
"modified_at": "2024-01-01T00:00:00Z",
"size": 500000,
"digest": "ghi789",
"details": {
"format": "gguf",
"family": "nomic-embed",
"families": ["nomic-embed"],
"parameter_size": "137M",
"quantization_level": "Q4_0"
}
}
]
});
fake_http_client.set_response("/api/tags", models_response.to_string());
// Mock /api/show responses for model capabilities
let qwen_capabilities = serde_json::json!({
"capabilities": ["tools", "thinking"]
});
let _codellama_capabilities = serde_json::json!({
"capabilities": []
});
fake_http_client.set_response("/api/show", qwen_capabilities.to_string());
// Create global Ollama service for testing
let service = cx.update(|cx| {
OllamaService::new(
fake_http_client.clone(),
"http://localhost:11434".to_string(),
cx,
)
});
// Set it as global
cx.update(|cx| {
OllamaService::set_global(service.clone(), cx);
});
// Create completion provider
let provider = cx.update(|cx| {
cx.new(|cx| OllamaCompletionProvider::new("qwen2.5-coder:3b".to_string(), None, cx))
});
// Wait for model discovery to complete
cx.background_executor.run_until_parked();
// Verify models were discovered through the global service
provider.read_with(cx, |provider, cx| {
let models = provider.available_models(cx);
assert_eq!(models.len(), 2); // Should exclude nomic-embed-text
let model_names: Vec<&str> = models.iter().map(|m| m.name.as_str()).collect();
assert!(model_names.contains(&"codellama:7b-code"));
assert!(model_names.contains(&"qwen2.5-coder:3b"));
assert!(!model_names.contains(&"nomic-embed-text"));
});
}
#[gpui::test]
async fn test_model_discovery_api_failure(cx: &mut TestAppContext) {
init_test(cx);
// Create fake HTTP client that returns errors
let fake_http_client = Arc::new(crate::fake::FakeHttpClient::new());
fake_http_client.set_error("Connection refused");
// Create global Ollama service that will fail
let service = cx.update(|cx| {
OllamaService::new(
fake_http_client.clone(),
"http://localhost:11434".to_string(),
cx,
)
});
cx.update(|cx| {
OllamaService::set_global(service.clone(), cx);
});
// Create completion provider
let provider = cx.update(|cx| {
cx.new(|cx| OllamaCompletionProvider::new("qwen2.5-coder:3b".to_string(), None, cx))
});
// Wait for model discovery to complete (with failure)
cx.background_executor.run_until_parked();
// Verify graceful handling - should have empty model list
provider.read_with(cx, |provider, cx| {
let models = provider.available_models(cx);
assert_eq!(models.len(), 0);
});
}
#[gpui::test]
async fn test_refresh_models(cx: &mut TestAppContext) {
init_test(cx);
let fake_http_client = Arc::new(crate::fake::FakeHttpClient::new());
// Initially return empty model list
let empty_response = serde_json::json!({"models": []});
fake_http_client.set_response("/api/tags", empty_response.to_string());
// Create global Ollama service
let service = cx.update(|cx| {
OllamaService::new(
fake_http_client.clone(),
"http://localhost:11434".to_string(),
cx,
)
});
cx.update(|cx| {
OllamaService::set_global(service.clone(), cx);
});
let provider = cx.update(|cx| {
cx.new(|cx| OllamaCompletionProvider::new("qwen2.5-coder:7b".to_string(), None, cx))
});
cx.background_executor.run_until_parked();
// Verify initially empty
provider.read_with(cx, |provider, cx| {
assert_eq!(provider.available_models(cx).len(), 0);
});
// Update mock to return models
let models_response = serde_json::json!({
"models": [
{
"name": "qwen2.5-coder:7b",
"modified_at": "2024-01-01T00:00:00Z",
"size": 1000000,
"digest": "abc123",
"details": {
"format": "gguf",
"family": "qwen2",
"families": ["qwen2"],
"parameter_size": "7B",
"quantization_level": "Q4_0"
}
}
]
});
fake_http_client.set_response("/api/tags", models_response.to_string());
let capabilities = serde_json::json!({
"capabilities": ["tools", "thinking"]
});
fake_http_client.set_response("/api/show", capabilities.to_string());
// Trigger refresh
provider.update(cx, |provider, cx| {
provider.refresh_models(cx);
});
cx.background_executor.run_until_parked();
// Verify models were refreshed
provider.read_with(cx, |provider, cx| {
let models = provider.available_models(cx);
assert_eq!(models.len(), 1);
assert_eq!(models[0].name, "qwen2.5-coder:7b");
});
}
#[gpui::test]
@@ -306,12 +658,28 @@ mod tests {
buffer.anchor_before(11) // Position in the middle of the function
});
// Create Ollama provider with fake HTTP client
let (provider, fake_http_client) = Ollama::fake(cx);
// Configure mock HTTP response
// Create fake HTTP client and set up global service
let fake_http_client = Arc::new(crate::fake::FakeHttpClient::new());
fake_http_client.set_generate_response("println!(\"Hello\");");
// Create global Ollama service
let service = cx.update(|cx| {
OllamaService::new(
fake_http_client.clone(),
"http://localhost:11434".to_string(),
cx,
)
});
cx.update(|cx| {
OllamaService::set_global(service.clone(), cx);
});
// Create provider
let provider = cx.update(|cx| {
cx.new(|cx| OllamaCompletionProvider::new("qwen2.5-coder:3b".to_string(), None, cx))
});
// Trigger completion refresh (no debounce for test speed)
provider.update(cx, |provider, cx| {
provider.refresh(None, buffer.clone(), cursor_position, false, cx);
@@ -363,7 +731,26 @@ mod tests {
buffer.anchor_after(16) // After "vec"
});
let (provider, fake_http_client) = Ollama::fake(cx);
// Create fake HTTP client and set up global service
let fake_http_client = Arc::new(crate::fake::FakeHttpClient::new());
// Create global Ollama service
let service = cx.update(|cx| {
OllamaService::new(
fake_http_client.clone(),
"http://localhost:11434".to_string(),
cx,
)
});
cx.update(|cx| {
OllamaService::set_global(service.clone(), cx);
});
// Create provider
let provider = cx.update(|cx| {
cx.new(|cx| OllamaCompletionProvider::new("qwen2.5-coder:3b".to_string(), None, cx))
});
// Configure response that starts with what user already typed
fake_http_client.set_generate_response("vec![1, 2, 3]");
@@ -393,7 +780,28 @@ mod tests {
init_test(cx);
let mut editor_cx = editor::test::editor_test_context::EditorTestContext::new(cx).await;
let (provider, fake_http_client) = Ollama::fake(cx);
// Create fake HTTP client and set up global service
let fake_http_client = Arc::new(crate::fake::FakeHttpClient::new());
fake_http_client.set_generate_response("vec![hello, world]");
// Create global Ollama service
let service = cx.update(|cx| {
OllamaService::new(
fake_http_client.clone(),
"http://localhost:11434".to_string(),
cx,
)
});
cx.update(|cx| {
OllamaService::set_global(service.clone(), cx);
});
// Create provider
let provider = cx.update(|cx| {
cx.new(|cx| OllamaCompletionProvider::new("qwen2.5-coder:3b".to_string(), None, cx))
});
// Set up the editor with the Ollama provider
editor_cx.update_editor(|editor, window, cx| {
@@ -403,9 +811,6 @@ mod tests {
// Set initial state
editor_cx.set_state("let items = ˇ");
// Configure a multi-word completion
fake_http_client.set_generate_response("vec![hello, world]");
// Trigger the completion through the provider
let buffer =
editor_cx.multibuffer(|multibuffer, _| multibuffer.as_singleton().unwrap().clone());
@@ -455,7 +860,28 @@ mod tests {
init_test(cx);
let mut editor_cx = editor::test::editor_test_context::EditorTestContext::new(cx).await;
let (provider, fake_http_client) = Ollama::fake(cx);
// Create fake HTTP client and set up global service
let fake_http_client = Arc::new(crate::fake::FakeHttpClient::new());
fake_http_client.set_generate_response("bar");
// Create global Ollama service
let service = cx.update(|cx| {
OllamaService::new(
fake_http_client.clone(),
"http://localhost:11434".to_string(),
cx,
)
});
cx.update(|cx| {
OllamaService::set_global(service.clone(), cx);
});
// Create provider
let provider = cx.update(|cx| {
cx.new(|cx| OllamaCompletionProvider::new("qwen2.5-coder:3b".to_string(), None, cx))
});
// Set up the editor with the Ollama provider
editor_cx.update_editor(|editor, window, cx| {
@@ -464,9 +890,6 @@ mod tests {
editor_cx.set_state("fooˇ");
// Configure completion response that extends the current text
fake_http_client.set_generate_response("bar");
// Trigger the completion through the provider
let buffer =
editor_cx.multibuffer(|multibuffer, _| multibuffer.as_singleton().unwrap().clone());