Auto-detect Ollama models (WIP)

This commit is contained in:
Oliver Azevedo Barnes 2025-07-25 10:21:32 +01:00
parent 5a1506c3c2
commit 0bdb42e65d
No known key found for this signature in database
8 changed files with 952 additions and 128 deletions
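At the center of the diffs below is a shared OllamaService entity registered as a GPUI global, so the inline-completion button, the language-model provider, and the edit-prediction providers all read one discovered-model list instead of fetching independently. A minimal sketch of that registration pattern, assuming only the gpui crate (names simplified for illustration; the real types in this commit are OllamaService and GlobalOllamaService):

use gpui::{App, Entity, Global};

struct MyService {
    // discovered state lives here
}

// Newtype wrapper: gpui keys globals by type, so the wrapper keeps
// Entity<MyService> distinct from any other global entity.
struct GlobalMyService(Entity<MyService>);
impl Global for GlobalMyService {}

impl MyService {
    fn set_global(service: Entity<Self>, cx: &mut App) {
        cx.set_global(GlobalMyService(service));
    }

    fn global(cx: &App) -> Option<Entity<Self>> {
        // Returns None when set_global was never called at startup.
        cx.try_global::<GlobalMyService>().map(|global| global.0.clone())
    }
}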

Cargo.lock generated
View file

@@ -8363,6 +8363,7 @@ dependencies = [
"fs",
"futures 0.3.31",
"gpui",
"gpui_tokio",
"http_client",
"indoc",
"inline_completion",
@@ -8370,6 +8371,7 @@ dependencies = [
"language_model",
"language_models",
"lsp",
"ollama",
"paths",
"project",
"regex",
@@ -20253,6 +20255,7 @@ dependencies = [
"nix 0.29.0",
"node_runtime",
"notifications",
"ollama",
"onboarding",
"outline",
"outline_panel",

View file

@@ -25,6 +25,7 @@ indoc.workspace = true
inline_completion.workspace = true
language.workspace = true
language_models.workspace = true
ollama.workspace = true
paths.workspace = true
regex.workspace = true
@@ -48,6 +49,9 @@ http_client = { workspace = true, features = ["test-support"] }
indoc.workspace = true
language_model = { workspace = true, features = ["test-support"] }
lsp = { workspace = true, features = ["test-support"] }
ollama = { workspace = true, features = ["test-support"] }
project = { workspace = true, features = ["test-support"] }
serde_json.workspace = true
settings = { workspace = true, features = ["test-support"] }
theme = { workspace = true, features = ["test-support"] }
gpui_tokio.workspace = true

View file

@@ -21,6 +21,8 @@ use language::{
};
use language_models::AllLanguageModelSettings;
use ollama;
use paths;
use regex::Regex;
use settings::{Settings, SettingsStore, update_settings_file};
@@ -413,6 +415,10 @@ impl InlineCompletionButton {
cx.observe_global::<SettingsStore>(move |_, cx| cx.notify())
.detach();
if let Some(service) = ollama::OllamaService::global(cx) {
cx.observe(&service, |_, _, cx| cx.notify()).detach();
}
Self {
editor_subscription: None,
editor_enabled: None,
@@ -858,8 +864,30 @@ impl InlineCompletionButton {
let settings = AllLanguageModelSettings::get_global(cx);
let ollama_settings = &settings.ollama;
// Clone needed values to avoid borrowing issues
let available_models = ollama_settings.available_models.clone();
// Get models from both settings and global service discovery
let mut available_models = ollama_settings.available_models.clone();
// Add discovered models from the global Ollama service
if let Some(service) = ollama::OllamaService::global(cx) {
let discovered_models = service.read(cx).available_models();
for model in discovered_models {
// Convert from ollama::Model to language_models AvailableModel
let available_model = language_models::provider::ollama::AvailableModel {
name: model.name.clone(),
display_name: model.display_name.clone(),
max_tokens: model.max_tokens,
keep_alive: model.keep_alive.clone(),
supports_tools: model.supports_tools,
supports_images: model.supports_vision,
supports_thinking: model.supports_thinking,
};
// Add if not already in settings (settings take precedence)
if !available_models.iter().any(|m| m.name == model.name) {
available_models.push(available_model);
}
}
}
// API URL configuration - only show if Ollama settings exist in the user's config
let menu = if Self::ollama_settings_exist(cx) {
@@ -878,7 +906,7 @@
let menu = menu.separator().header("Available Models");
// Add each available model as a menu entry
available_models.iter().fold(menu, |menu, model| {
let menu = available_models.iter().fold(menu, |menu, model| {
let model_name = model.display_name.as_ref().unwrap_or(&model.name);
let is_current = available_models
.first()
@@ -898,6 +926,13 @@
}
},
)
});
// Add refresh models option
menu.separator().entry("Refresh Models", None, {
move |_window, cx| {
Self::refresh_ollama_models(cx);
}
})
} else {
menu.separator()
@@ -908,6 +943,11 @@
Self::open_ollama_settings(fs.clone(), window, cx);
}
})
.entry("Refresh Models", None, {
move |_window, cx| {
Self::refresh_ollama_models(cx);
}
})
};
// Use the common language settings menu
@@ -997,6 +1037,14 @@
});
}
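// Forwards the refresh to the global Ollama service, if one was registered
// at startup; the menu rebuilds via the observe subscription set up in new().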
fn refresh_ollama_models(cx: &mut App) {
if let Some(service) = ollama::OllamaService::global(cx) {
service.update(cx, |service, cx| {
service.refresh_models(cx);
});
}
}
pub fn update_enabled(&mut self, editor: Entity<Editor>, cx: &mut Context<Self>) {
let editor = editor.read(cx);
let snapshot = editor.buffer().read(cx).snapshot(cx);
@@ -1188,3 +1236,359 @@ fn toggle_edit_prediction_mode(fs: Arc<dyn Fs>, mode: EditPredictionsMode, cx: &
});
}
}
#[cfg(test)]
mod tests {
use super::*;
use clock::FakeSystemClock;
use gpui::TestAppContext;
use http_client;
use language_models::provider::ollama::AvailableModel;
use ollama::{OllamaService, fake::FakeHttpClient};
use settings::SettingsStore;
use std::sync::Arc;
fn init_test(cx: &mut TestAppContext) {
cx.update(|cx| {
let settings_store = SettingsStore::test(cx);
cx.set_global(settings_store);
gpui_tokio::init(cx);
theme::init(theme::LoadThemes::JustBase, cx);
language::init(cx);
language_settings::init(cx);
});
}
#[gpui::test]
async fn test_ollama_menu_shows_discovered_models(cx: &mut TestAppContext) {
init_test(cx);
// Create fake HTTP client with mock models response
let fake_http_client = Arc::new(FakeHttpClient::new());
// Mock /api/tags response
let models_response = serde_json::json!({
"models": [
{
"name": "qwen2.5-coder:3b",
"modified_at": "2024-01-01T00:00:00Z",
"size": 1000000,
"digest": "abc123",
"details": {
"format": "gguf",
"family": "qwen2",
"families": ["qwen2"],
"parameter_size": "3B",
"quantization_level": "Q4_0"
}
},
{
"name": "codellama:7b-code",
"modified_at": "2024-01-01T00:00:00Z",
"size": 2000000,
"digest": "def456",
"details": {
"format": "gguf",
"family": "codellama",
"families": ["codellama"],
"parameter_size": "7B",
"quantization_level": "Q4_0"
}
}
]
});
fake_http_client.set_response("/api/tags", models_response.to_string());
// Mock /api/show response
let capabilities = serde_json::json!({
"capabilities": ["tools"]
});
fake_http_client.set_response("/api/show", capabilities.to_string());
// Create and set global Ollama service
let service = cx.update(|cx| {
OllamaService::new(
fake_http_client.clone(),
"http://localhost:11434".to_string(),
cx,
)
});
cx.update(|cx| {
OllamaService::set_global(service.clone(), cx);
});
// Wait for model discovery
cx.background_executor.run_until_parked();
// Verify models are accessible through the service
cx.update(|cx| {
if let Some(service) = OllamaService::global(cx) {
let discovered_models = service.read(cx).available_models();
assert_eq!(discovered_models.len(), 2);
let model_names: Vec<&str> =
discovered_models.iter().map(|m| m.name.as_str()).collect();
assert!(model_names.contains(&"qwen2.5-coder:3b"));
assert!(model_names.contains(&"codellama:7b-code"));
} else {
panic!("Global service should be available");
}
});
// Verify the global service has the expected models
service.read_with(cx, |service, _| {
let models = service.available_models();
assert_eq!(models.len(), 2);
let model_names: Vec<&str> = models.iter().map(|m| m.name.as_str()).collect();
assert!(model_names.contains(&"qwen2.5-coder:3b"));
assert!(model_names.contains(&"codellama:7b-code"));
});
}
#[gpui::test]
async fn test_ollama_menu_shows_service_models(cx: &mut TestAppContext) {
init_test(cx);
// Create fake HTTP client with models
let fake_http_client = Arc::new(FakeHttpClient::new());
let models_response = serde_json::json!({
"models": [
{
"name": "qwen2.5-coder:7b",
"modified_at": "2024-01-01T00:00:00Z",
"size": 1000000,
"digest": "abc123",
"details": {
"format": "gguf",
"family": "qwen2",
"families": ["qwen2"],
"parameter_size": "7B",
"quantization_level": "Q4_0"
}
}
]
});
fake_http_client.set_response("/api/tags", models_response.to_string());
fake_http_client.set_response(
"/api/show",
serde_json::json!({"capabilities": []}).to_string(),
);
// Create and set global service
let service = cx.update(|cx| {
OllamaService::new(fake_http_client, "http://localhost:11434".to_string(), cx)
});
cx.update(|cx| {
OllamaService::set_global(service.clone(), cx);
});
cx.background_executor.run_until_parked();
// Test that discovered models are accessible
cx.update(|cx| {
if let Some(service) = OllamaService::global(cx) {
let discovered_models = service.read(cx).available_models();
assert_eq!(discovered_models.len(), 1);
assert_eq!(discovered_models[0].name, "qwen2.5-coder:7b");
} else {
panic!("Global service should be available");
}
});
}
#[gpui::test]
async fn test_ollama_menu_refreshes_on_service_update(cx: &mut TestAppContext) {
init_test(cx);
let fake_http_client = Arc::new(FakeHttpClient::new());
// Initially empty models
fake_http_client.set_response("/api/tags", serde_json::json!({"models": []}).to_string());
let service = cx.update(|cx| {
OllamaService::new(
fake_http_client.clone(),
"http://localhost:11434".to_string(),
cx,
)
});
cx.update(|cx| {
OllamaService::set_global(service.clone(), cx);
});
cx.background_executor.run_until_parked();
// Verify the service subscription mechanism works by creating a button
let _button = cx.update(|cx| {
let fs = fs::FakeFs::new(cx.background_executor().clone());
let user_store = cx.new(|cx| {
client::UserStore::new(
Arc::new(http_client::FakeHttpClient::create(|_| {
Box::pin(async { Err(anyhow::anyhow!("not implemented")) })
})),
cx,
)
});
let popover_handle = PopoverMenuHandle::default();
cx.new(|cx| InlineCompletionButton::new(fs, user_store, popover_handle, cx))
});
// Verify initially no models
service.read_with(cx, |service, _| {
assert_eq!(service.available_models().len(), 0);
});
// Update mock to return models
let models_response = serde_json::json!({
"models": [
{
"name": "phi3:mini",
"modified_at": "2024-01-01T00:00:00Z",
"size": 500000,
"digest": "xyz789",
"details": {
"format": "gguf",
"family": "phi3",
"families": ["phi3"],
"parameter_size": "3.8B",
"quantization_level": "Q4_0"
}
}
]
});
fake_http_client.set_response("/api/tags", models_response.to_string());
fake_http_client.set_response(
"/api/show",
serde_json::json!({"capabilities": []}).to_string(),
);
// Trigger refresh
service.update(cx, |service, cx| {
service.refresh_models(cx);
});
cx.background_executor.run_until_parked();
// Verify models were refreshed
service.read_with(cx, |service, _| {
let models = service.available_models();
assert_eq!(models.len(), 1);
assert_eq!(models[0].name, "phi3:mini");
});
// The button should have been notified and will rebuild its menu with new models
// when next requested (this tests the subscription mechanism)
}
#[gpui::test]
async fn test_refresh_models_button_functionality(cx: &mut TestAppContext) {
init_test(cx);
let fake_http_client = Arc::new(FakeHttpClient::new());
// Start with one model
let initial_response = serde_json::json!({
"models": [
{
"name": "mistral:7b",
"modified_at": "2024-01-01T00:00:00Z",
"size": 1000000,
"digest": "initial123",
"details": {
"format": "gguf",
"family": "mistral",
"families": ["mistral"],
"parameter_size": "7B",
"quantization_level": "Q4_0"
}
}
]
});
fake_http_client.set_response("/api/tags", initial_response.to_string());
fake_http_client.set_response(
"/api/show",
serde_json::json!({"capabilities": []}).to_string(),
);
let service = cx.update(|cx| {
OllamaService::new(
fake_http_client.clone(),
"http://localhost:11434".to_string(),
cx,
)
});
cx.update(|cx| {
OllamaService::set_global(service.clone(), cx);
});
cx.background_executor.run_until_parked();
// Verify initial model
service.read_with(cx, |service, _| {
assert_eq!(service.available_models().len(), 1);
assert_eq!(service.available_models()[0].name, "mistral:7b");
});
// Update mock to simulate new model available
let updated_response = serde_json::json!({
"models": [
{
"name": "mistral:7b",
"modified_at": "2024-01-01T00:00:00Z",
"size": 1000000,
"digest": "initial123",
"details": {
"format": "gguf",
"family": "mistral",
"families": ["mistral"],
"parameter_size": "7B",
"quantization_level": "Q4_0"
}
},
{
"name": "gemma2:9b",
"modified_at": "2024-01-01T00:00:00Z",
"size": 2000000,
"digest": "new456",
"details": {
"format": "gguf",
"family": "gemma2",
"families": ["gemma2"],
"parameter_size": "9B",
"quantization_level": "Q4_0"
}
}
]
});
fake_http_client.set_response("/api/tags", updated_response.to_string());
// Simulate clicking "Refresh Models" button
cx.update(|cx| {
InlineCompletionButton::refresh_ollama_models(cx);
});
cx.background_executor.run_until_parked();
// Verify models were refreshed
service.read_with(cx, |service, _| {
let models = service.available_models();
assert_eq!(models.len(), 2);
let model_names: Vec<&str> = models.iter().map(|m| m.name.as_str()).collect();
assert!(model_names.contains(&"mistral:7b"));
assert!(model_names.contains(&"gemma2:9b"));
});
}
}

View file

@@ -1,7 +1,7 @@
use anyhow::{Result, anyhow};
use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream};
use futures::{Stream, TryFutureExt, stream};
use gpui::{AnyView, App, AsyncApp, Context, Subscription, Task};
use gpui::{AnyView, App, AsyncApp, Context, Entity, Global, Subscription, Task};
use http_client::HttpClient;
use language_model::{
AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
@@ -141,6 +141,29 @@ impl State {
}
impl OllamaLanguageModelProvider {
pub fn global(cx: &App) -> Option<Entity<Self>> {
cx.try_global::<GlobalOllamaLanguageModelProvider>()
.map(|provider| provider.0.clone())
}
pub fn set_global(provider: Entity<Self>, cx: &mut App) {
cx.set_global(GlobalOllamaLanguageModelProvider(provider));
}
pub fn available_models_for_completion(&self, cx: &App) -> Vec<ollama::Model> {
self.state.read(cx).available_models.clone()
}
pub fn http_client(&self) -> Arc<dyn HttpClient> {
self.http_client.clone()
}
pub fn refresh_models(&self, cx: &mut App) {
self.state.update(cx, |state, cx| {
state.restart_fetch_models_task(cx);
});
}
pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
let this = Self {
http_client: http_client.clone(),
@@ -667,6 +690,10 @@ impl Render for ConfigurationView {
}
}
struct GlobalOllamaLanguageModelProvider(Entity<OllamaLanguageModelProvider>);
impl Global for GlobalOllamaLanguageModelProvider {}
fn tool_into_ollama(tool: LanguageModelRequestTool) -> ollama::OllamaTool {
ollama::OllamaTool::Function {
function: OllamaFunctionTool {
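The new global()/set_global() accessors on OllamaLanguageModelProvider follow the same lookup shape as the service. A hypothetical call site (assuming an App context named cx and the log crate; this caller is illustrative, not part of the commit):

// Hypothetical caller: list discovered models through the global provider.
if let Some(provider) = OllamaLanguageModelProvider::global(cx) {
    for model in provider.read(cx).available_models_for_completion(cx) {
        log::info!("ollama model available: {}", model.name);
    }
}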

View file

@@ -30,6 +30,7 @@ language.workspace = true
log.workspace = true
project.workspace = true
settings.workspace = true
schemars = { workspace = true, optional = true }
serde.workspace = true
serde_json.workspace = true

View file

@@ -541,14 +541,8 @@ pub mod fake {
) {
let fake_client = std::sync::Arc::new(FakeHttpClient::new());
let provider = cx.new(|_| {
OllamaCompletionProvider::new(
fake_client.clone(),
"http://localhost:11434".to_string(),
"qwencoder".to_string(),
None,
)
});
let provider =
cx.new(|cx| OllamaCompletionProvider::new("qwencoder".to_string(), None, cx));
(provider, fake_client)
}

View file

@@ -1,43 +1,172 @@
use crate::{GenerateOptions, GenerateRequest, generate};
use crate::{GenerateOptions, GenerateRequest, Model, generate};
use anyhow::{Context as AnyhowContext, Result};
use futures::StreamExt;
use std::{path::Path, sync::Arc, time::Duration};
use gpui::{App, Context, Entity, EntityId, Task};
use gpui::{App, AppContext, Context, Entity, EntityId, Global, Subscription, Task};
use http_client::HttpClient;
use inline_completion::{Direction, EditPredictionProvider, InlineCompletion};
use language::{Anchor, Buffer, ToOffset};
use settings::SettingsStore;
use project::Project;
use std::{path::Path, sync::Arc, time::Duration};
pub const OLLAMA_DEBOUNCE_TIMEOUT: Duration = Duration::from_millis(75);
pub struct OllamaCompletionProvider {
// Global Ollama service for managing models across all providers
pub struct OllamaService {
http_client: Arc<dyn HttpClient>,
api_url: String,
available_models: Vec<Model>,
fetch_models_task: Option<Task<Result<()>>>,
_settings_subscription: Subscription,
}
impl OllamaService {
pub fn new(http_client: Arc<dyn HttpClient>, api_url: String, cx: &mut App) -> Entity<Self> {
cx.new(|cx| {
let subscription = cx.observe_global::<SettingsStore>({
move |this: &mut OllamaService, cx| {
this.restart_fetch_models_task(cx);
}
});
let mut service = Self {
http_client,
api_url,
available_models: Vec::new(),
fetch_models_task: None,
_settings_subscription: subscription,
};
service.restart_fetch_models_task(cx);
service
})
}
pub fn global(cx: &App) -> Option<Entity<Self>> {
cx.try_global::<GlobalOllamaService>()
.map(|service| service.0.clone())
}
pub fn set_global(service: Entity<Self>, cx: &mut App) {
cx.set_global(GlobalOllamaService(service));
}
pub fn available_models(&self) -> &[Model] {
&self.available_models
}
pub fn refresh_models(&mut self, cx: &mut Context<Self>) {
self.restart_fetch_models_task(cx);
}
fn restart_fetch_models_task(&mut self, cx: &mut Context<Self>) {
self.fetch_models_task = Some(self.fetch_models(cx));
}
fn fetch_models(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
let http_client = Arc::clone(&self.http_client);
let api_url = self.api_url.clone();
cx.spawn(async move |this, cx| {
let models = match crate::get_models(http_client.as_ref(), &api_url, None).await {
Ok(models) => models,
Err(_) => return Ok(()), // Fail silently and keep the current model list
};
let tasks = models
.into_iter()
// Filter out embedding models
.filter(|model| !model.name.contains("-embed"))
.map(|model| {
let http_client = Arc::clone(&http_client);
let api_url = api_url.clone();
async move {
let name = model.name.as_str();
let capabilities =
crate::show_model(http_client.as_ref(), &api_url, name).await?;
let ollama_model = Model::new(
name,
None,
None,
Some(capabilities.supports_tools()),
Some(capabilities.supports_vision()),
Some(capabilities.supports_thinking()),
);
Ok(ollama_model)
}
});
// Rate-limit capability fetches
let mut ollama_models: Vec<_> = futures::stream::iter(tasks)
.buffer_unordered(5)
.collect::<Vec<Result<_>>>()
.await
.into_iter()
.filter_map(Result::ok) // Drop only models whose capability probe failed
.collect();
ollama_models.sort_by(|a, b| a.name.cmp(&b.name));
this.update(cx, |this, cx| {
this.available_models = ollama_models;
cx.notify();
})?;
Ok(())
})
}
}
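// Private newtype so the service entity can live in gpui's type-keyed global
// map; reached only through OllamaService::global / set_global above.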
struct GlobalOllamaService(Entity<OllamaService>);
impl Global for GlobalOllamaService {}
pub struct OllamaCompletionProvider {
model: String,
buffer_id: Option<EntityId>,
file_extension: Option<String>,
current_completion: Option<String>,
pending_refresh: Option<Task<Result<()>>>,
api_key: Option<String>,
_service_subscription: Option<Subscription>,
}
impl OllamaCompletionProvider {
pub fn new(
http_client: Arc<dyn HttpClient>,
api_url: String,
model: String,
api_key: Option<String>,
) -> Self {
pub fn new(model: String, api_key: Option<String>, cx: &mut Context<Self>) -> Self {
let subscription = if let Some(service) = OllamaService::global(cx) {
Some(cx.observe(&service, |_this, _service, cx| {
cx.notify();
}))
} else {
None
};
Self {
http_client,
api_url,
model,
buffer_id: None,
file_extension: None,
current_completion: None,
pending_refresh: None,
api_key,
_service_subscription: subscription,
}
}
pub fn available_models(&self, cx: &App) -> Vec<Model> {
if let Some(service) = OllamaService::global(cx) {
service.read(cx).available_models().to_vec()
} else {
Vec::new()
}
}
pub fn refresh_models(&self, cx: &mut App) {
if let Some(service) = OllamaService::global(cx) {
service.update(cx, |service, cx| {
service.refresh_models(cx);
});
}
}
@@ -104,14 +233,28 @@ impl EditPredictionProvider for OllamaCompletionProvider {
fn refresh(
&mut self,
_project: Option<Entity<Project>>,
project: Option<Entity<Project>>,
buffer: Entity<Buffer>,
cursor_position: Anchor,
debounce: bool,
cx: &mut Context<Self>,
) {
let http_client = self.http_client.clone();
let api_url = self.api_url.clone();
// Get API settings from the global Ollama service or fallback
let (http_client, api_url) = if let Some(service) = OllamaService::global(cx) {
let service_ref = service.read(cx);
(service_ref.http_client.clone(), service_ref.api_url.clone())
} else {
// Fallback if global service isn't available
(
project
.as_ref()
.map(|p| p.read(cx).client().http_client() as Arc<dyn HttpClient>)
.unwrap_or_else(|| {
Arc::new(http_client::BlockedHttpClient::new()) as Arc<dyn HttpClient>
}),
crate::OLLAMA_API_URL.to_string(),
)
};
self.pending_refresh = Some(cx.spawn(async move |this, cx| {
if debounce {
@@ -156,14 +299,17 @@ impl EditPredictionProvider for OllamaCompletionProvider {
let response = generate(http_client.as_ref(), &api_url, api_key, request)
.await
.context("Failed to get completion from Ollama")?;
.context("Failed to get completion from Ollama");
this.update(cx, |this, cx| {
this.pending_refresh = None;
if !response.response.trim().is_empty() {
this.current_completion = Some(response.response);
} else {
this.current_completion = None;
match response {
Ok(response) if !response.response.trim().is_empty() => {
this.current_completion = Some(response.response);
}
_ => {
this.current_completion = None;
}
}
cx.notify();
})?;
@@ -248,7 +394,6 @@ impl EditPredictionProvider for OllamaCompletionProvider {
#[cfg(test)]
mod tests {
use super::*;
use crate::fake::Ollama;
use gpui::{AppContext, TestAppContext};
@@ -269,31 +414,238 @@ mod tests {
}
/// Stop tokens are model-specific: CodeLlama variants need <EOT>, other models need none
#[test]
fn test_get_stop_tokens() {
let http_client = Arc::new(crate::fake::FakeHttpClient::new());
#[gpui::test]
fn test_get_stop_tokens(cx: &mut TestAppContext) {
init_test(cx);
// Test CodeLlama code model gets stop tokens
let codellama_provider = OllamaCompletionProvider::new(
http_client.clone(),
"http://localhost:11434".to_string(),
"codellama:7b-code".to_string(),
None,
);
let codellama_provider = cx.update(|cx| {
cx.new(|cx| OllamaCompletionProvider::new("codellama:7b-code".to_string(), None, cx))
});
assert_eq!(
codellama_provider.get_stop_tokens(),
Some(vec!["<EOT>".to_string()])
);
codellama_provider.read_with(cx, |provider, _| {
assert_eq!(provider.get_stop_tokens(), Some(vec!["<EOT>".to_string()]));
});
// Test non-CodeLlama model doesn't get stop tokens
let qwen_provider = OllamaCompletionProvider::new(
http_client.clone(),
"http://localhost:11434".to_string(),
"qwen2.5-coder:3b".to_string(),
None,
);
assert_eq!(qwen_provider.get_stop_tokens(), None);
let qwen_provider = cx.update(|cx| {
cx.new(|cx| OllamaCompletionProvider::new("qwen2.5-coder:3b".to_string(), None, cx))
});
qwen_provider.read_with(cx, |provider, _| {
assert_eq!(provider.get_stop_tokens(), None);
});
}
#[gpui::test]
async fn test_model_discovery(cx: &mut TestAppContext) {
init_test(cx);
// Create fake HTTP client
let fake_http_client = Arc::new(crate::fake::FakeHttpClient::new());
// Mock /api/tags response (list models)
let models_response = serde_json::json!({
"models": [
{
"name": "qwen2.5-coder:3b",
"modified_at": "2024-01-01T00:00:00Z",
"size": 1000000,
"digest": "abc123",
"details": {
"format": "gguf",
"family": "qwen2",
"families": ["qwen2"],
"parameter_size": "3B",
"quantization_level": "Q4_0"
}
},
{
"name": "codellama:7b-code",
"modified_at": "2024-01-01T00:00:00Z",
"size": 2000000,
"digest": "def456",
"details": {
"format": "gguf",
"family": "codellama",
"families": ["codellama"],
"parameter_size": "7B",
"quantization_level": "Q4_0"
}
},
{
"name": "nomic-embed-text",
"modified_at": "2024-01-01T00:00:00Z",
"size": 500000,
"digest": "ghi789",
"details": {
"format": "gguf",
"family": "nomic-embed",
"families": ["nomic-embed"],
"parameter_size": "137M",
"quantization_level": "Q4_0"
}
}
]
});
fake_http_client.set_response("/api/tags", models_response.to_string());
// Mock /api/show responses for model capabilities
let qwen_capabilities = serde_json::json!({
"capabilities": ["tools", "thinking"]
});
let _codellama_capabilities = serde_json::json!({
"capabilities": []
});
fake_http_client.set_response("/api/show", qwen_capabilities.to_string());
// Create global Ollama service for testing
let service = cx.update(|cx| {
OllamaService::new(
fake_http_client.clone(),
"http://localhost:11434".to_string(),
cx,
)
});
// Set it as global
cx.update(|cx| {
OllamaService::set_global(service.clone(), cx);
});
// Create completion provider
let provider = cx.update(|cx| {
cx.new(|cx| OllamaCompletionProvider::new("qwen2.5-coder:3b".to_string(), None, cx))
});
// Wait for model discovery to complete
cx.background_executor.run_until_parked();
// Verify models were discovered through the global provider
provider.read_with(cx, |provider, cx| {
let models = provider.available_models(cx);
assert_eq!(models.len(), 2); // Should exclude nomic-embed-text
let model_names: Vec<&str> = models.iter().map(|m| m.name.as_str()).collect();
assert!(model_names.contains(&"codellama:7b-code"));
assert!(model_names.contains(&"qwen2.5-coder:3b"));
assert!(!model_names.contains(&"nomic-embed-text"));
});
}
#[gpui::test]
async fn test_model_discovery_api_failure(cx: &mut TestAppContext) {
init_test(cx);
// Create fake HTTP client that returns errors
let fake_http_client = Arc::new(crate::fake::FakeHttpClient::new());
fake_http_client.set_error("Connection refused");
// Create global Ollama service that will fail
let service = cx.update(|cx| {
OllamaService::new(
fake_http_client.clone(),
"http://localhost:11434".to_string(),
cx,
)
});
cx.update(|cx| {
OllamaService::set_global(service.clone(), cx);
});
// Create completion provider
let provider = cx.update(|cx| {
cx.new(|cx| OllamaCompletionProvider::new("qwen2.5-coder:3b".to_string(), None, cx))
});
// Wait for model discovery to complete (with failure)
cx.background_executor.run_until_parked();
// Verify graceful handling - should have empty model list
provider.read_with(cx, |provider, cx| {
let models = provider.available_models(cx);
assert_eq!(models.len(), 0);
});
}
#[gpui::test]
async fn test_refresh_models(cx: &mut TestAppContext) {
init_test(cx);
let fake_http_client = Arc::new(crate::fake::FakeHttpClient::new());
// Initially return empty model list
let empty_response = serde_json::json!({"models": []});
fake_http_client.set_response("/api/tags", empty_response.to_string());
// Create global Ollama service
let service = cx.update(|cx| {
OllamaService::new(
fake_http_client.clone(),
"http://localhost:11434".to_string(),
cx,
)
});
cx.update(|cx| {
OllamaService::set_global(service.clone(), cx);
});
let provider = cx.update(|cx| {
cx.new(|cx| OllamaCompletionProvider::new("qwen2.5-coder:7b".to_string(), None, cx))
});
cx.background_executor.run_until_parked();
// Verify initially empty
provider.read_with(cx, |provider, cx| {
assert_eq!(provider.available_models(cx).len(), 0);
});
// Update mock to return models
let models_response = serde_json::json!({
"models": [
{
"name": "qwen2.5-coder:7b",
"modified_at": "2024-01-01T00:00:00Z",
"size": 1000000,
"digest": "abc123",
"details": {
"format": "gguf",
"family": "qwen2",
"families": ["qwen2"],
"parameter_size": "7B",
"quantization_level": "Q4_0"
}
}
]
});
fake_http_client.set_response("/api/tags", models_response.to_string());
let capabilities = serde_json::json!({
"capabilities": ["tools", "thinking"]
});
fake_http_client.set_response("/api/show", capabilities.to_string());
// Trigger refresh
provider.update(cx, |provider, cx| {
provider.refresh_models(cx);
});
cx.background_executor.run_until_parked();
// Verify models were refreshed
provider.read_with(cx, |provider, cx| {
let models = provider.available_models(cx);
assert_eq!(models.len(), 1);
assert_eq!(models[0].name, "qwen2.5-coder:7b");
});
}
#[gpui::test]
@@ -306,12 +658,28 @@ mod tests {
buffer.anchor_before(11) // Position in the middle of the function
});
// Create Ollama provider with fake HTTP client
let (provider, fake_http_client) = Ollama::fake(cx);
// Configure mock HTTP response
// Create fake HTTP client and set up global service
let fake_http_client = Arc::new(crate::fake::FakeHttpClient::new());
fake_http_client.set_generate_response("println!(\"Hello\");");
// Create global Ollama service
let service = cx.update(|cx| {
OllamaService::new(
fake_http_client.clone(),
"http://localhost:11434".to_string(),
cx,
)
});
cx.update(|cx| {
OllamaService::set_global(service.clone(), cx);
});
// Create provider
let provider = cx.update(|cx| {
cx.new(|cx| OllamaCompletionProvider::new("qwen2.5-coder:3b".to_string(), None, cx))
});
// Trigger completion refresh (no debounce for test speed)
provider.update(cx, |provider, cx| {
provider.refresh(None, buffer.clone(), cursor_position, false, cx);
@@ -363,7 +731,26 @@ mod tests {
buffer.anchor_after(16) // After "vec"
});
let (provider, fake_http_client) = Ollama::fake(cx);
// Create fake HTTP client and set up global service
let fake_http_client = Arc::new(crate::fake::FakeHttpClient::new());
// Create global Ollama service
let service = cx.update(|cx| {
OllamaService::new(
fake_http_client.clone(),
"http://localhost:11434".to_string(),
cx,
)
});
cx.update(|cx| {
OllamaService::set_global(service.clone(), cx);
});
// Create provider
let provider = cx.update(|cx| {
cx.new(|cx| OllamaCompletionProvider::new("qwen2.5-coder:3b".to_string(), None, cx))
});
// Configure response that starts with what user already typed
fake_http_client.set_generate_response("vec![1, 2, 3]");
@@ -393,7 +780,28 @@ mod tests {
init_test(cx);
let mut editor_cx = editor::test::editor_test_context::EditorTestContext::new(cx).await;
let (provider, fake_http_client) = Ollama::fake(cx);
// Create fake HTTP client and set up global service
let fake_http_client = Arc::new(crate::fake::FakeHttpClient::new());
fake_http_client.set_generate_response("vec![hello, world]");
// Create global Ollama service
let service = cx.update(|cx| {
OllamaService::new(
fake_http_client.clone(),
"http://localhost:11434".to_string(),
cx,
)
});
cx.update(|cx| {
OllamaService::set_global(service.clone(), cx);
});
// Create provider
let provider = cx.update(|cx| {
cx.new(|cx| OllamaCompletionProvider::new("qwen2.5-coder:3b".to_string(), None, cx))
});
// Set up the editor with the Ollama provider
editor_cx.update_editor(|editor, window, cx| {
@@ -403,9 +811,6 @@ mod tests {
// Set initial state
editor_cx.set_state("let items = ˇ");
// Configure a multi-word completion
fake_http_client.set_generate_response("vec![hello, world]");
// Trigger the completion through the provider
let buffer =
editor_cx.multibuffer(|multibuffer, _| multibuffer.as_singleton().unwrap().clone());
@@ -455,7 +860,28 @@ mod tests {
init_test(cx);
let mut editor_cx = editor::test::editor_test_context::EditorTestContext::new(cx).await;
let (provider, fake_http_client) = Ollama::fake(cx);
// Create fake HTTP client and set up global service
let fake_http_client = Arc::new(crate::fake::FakeHttpClient::new());
fake_http_client.set_generate_response("bar");
// Create global Ollama service
let service = cx.update(|cx| {
OllamaService::new(
fake_http_client.clone(),
"http://localhost:11434".to_string(),
cx,
)
});
cx.update(|cx| {
OllamaService::set_global(service.clone(), cx);
});
// Create provider
let provider = cx.update(|cx| {
cx.new(|cx| OllamaCompletionProvider::new("qwen2.5-coder:3b".to_string(), None, cx))
});
// Set up the editor with the Ollama provider
editor_cx.update_editor(|editor, window, cx| {
@@ -464,9 +890,6 @@ mod tests {
editor_cx.set_state("fooˇ");
// Configure completion response that extends the current text
fake_http_client.set_generate_response("bar");
// Trigger the completion through the provider
let buffer =
editor_cx.multibuffer(|multibuffer, _| multibuffer.as_singleton().unwrap().clone());
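These tests stub /api/show with a bare capabilities array, which is the payload shape the discovery code's capability probe relies on. A self-contained sketch of that assumed shape (hypothetical struct for illustration, not the crate's real response type):

use serde::Deserialize;

// Assumed /api/show payload, as stubbed in the tests above:
// {"capabilities": ["tools", "thinking"]}
#[derive(Deserialize)]
struct ShowResponse {
    #[serde(default)]
    capabilities: Vec<String>,
}

impl ShowResponse {
    fn supports_tools(&self) -> bool {
        self.capabilities.iter().any(|c| c == "tools")
    }

    fn supports_thinking(&self) -> bool {
        self.capabilities.iter().any(|c| c == "thinking")
    }
}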

View file

@@ -6,7 +6,7 @@ use gpui::{AnyWindowHandle, App, AppContext as _, Context, Entity, WeakEntity};
use language::language_settings::{EditPredictionProvider, all_language_settings};
use language_models::AllLanguageModelSettings;
use ollama::OllamaCompletionProvider;
use ollama::{OllamaCompletionProvider, OllamaService};
use settings::{Settings, SettingsStore};
use smol::stream::StreamExt;
use std::{cell::RefCell, rc::Rc, sync::Arc};
@@ -18,6 +18,11 @@ use zed_actions;
use zeta::{ProviderDataCollection, ZetaInlineCompletionProvider};
pub fn init(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) {
// Initialize global Ollama service
let settings = &AllLanguageModelSettings::get_global(cx).ollama;
let ollama_service = OllamaService::new(client.http_client(), settings.api_url.clone(), cx);
OllamaService::set_global(ollama_service, cx);
let editors: Rc<RefCell<HashMap<WeakEntity<Editor>, AnyWindowHandle>>> = Rc::default();
cx.observe_new({
let editors = editors.clone();
@@ -138,8 +143,13 @@ pub fn init(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) {
}
}
} else if provider == EditPredictionProvider::Ollama {
// Update Ollama providers when settings change but provider stays the same
update_ollama_providers(&editors, &client, user_store.clone(), cx);
// Update global Ollama service when settings change
let _settings = &AllLanguageModelSettings::get_global(cx).ollama;
if let Some(service) = OllamaService::global(cx) {
service.update(cx, |service, cx| {
service.refresh_models(cx);
});
}
}
}
})
@@ -152,46 +162,6 @@ fn clear_zeta_edit_history(_: &zeta::ClearHistory, cx: &mut App) {
}
}
fn update_ollama_providers(
editors: &Rc<RefCell<HashMap<WeakEntity<Editor>, AnyWindowHandle>>>,
client: &Arc<Client>,
user_store: Entity<UserStore>,
cx: &mut App,
) {
let settings = &AllLanguageModelSettings::get_global(cx).ollama;
let _current_model = settings
.available_models
.first()
.map(|m| m.name.clone())
.unwrap_or_else(|| "codellama:7b".to_string());
for (editor, window) in editors.borrow().iter() {
_ = window.update(cx, |_window, window, cx| {
_ = editor.update(cx, |editor, cx| {
if let Some(provider) = editor.edit_prediction_provider() {
// Check if this is an Ollama provider by comparing names
if provider.name() == "ollama" {
// Recreate the provider with the new model
let settings = &AllLanguageModelSettings::get_global(cx).ollama;
let _api_url = settings.api_url.clone();
// Get client from the registry context (need to pass it)
// For now, we'll trigger a full reassignment
assign_edit_prediction_provider(
editor,
EditPredictionProvider::Ollama,
&client,
user_store.clone(),
window,
cx,
);
}
}
})
});
}
}
fn assign_edit_prediction_providers(
editors: &Rc<RefCell<HashMap<WeakEntity<Editor>, AnyWindowHandle>>>,
provider: EditPredictionProvider,
@@ -333,27 +303,25 @@ fn assign_edit_prediction_provider(
}
EditPredictionProvider::Ollama => {
let settings = &AllLanguageModelSettings::get_global(cx).ollama;
let api_key = std::env::var("OLLAMA_API_KEY").ok();
// Only create provider if models are configured
// Note: Only FIM-capable models work with inline completion:
// ✓ Supported: qwen2.5-coder:*, starcoder2:*, codeqwen:*
// ✗ Not supported: codellama:*, deepseek-coder:*, llama3:*
if let Some(first_model) = settings.available_models.first() {
let api_url = settings.api_url.clone();
let model = first_model.name.clone();
// Get API key from environment variable only (credentials would require async handling)
let api_key = std::env::var("OLLAMA_API_KEY").ok();
let provider = cx.new(|_| {
OllamaCompletionProvider::new(client.http_client(), api_url, model, api_key)
});
editor.set_edit_prediction_provider(Some(provider), window, cx);
// Get model from settings or use discovered models
let model = if let Some(first_model) = settings.available_models.first() {
first_model.name.clone()
} else if let Some(service) = OllamaService::global(cx) {
// Use first discovered model
service
.read(cx)
.available_models()
.first()
.map(|m| m.name.clone())
.unwrap_or_else(|| "qwen2.5-coder:3b".to_string())
} else {
// No models configured - don't create a provider
// User will see "Configure Models" option in the completion menu
editor.set_edit_prediction_provider::<OllamaCompletionProvider>(None, window, cx);
}
"qwen2.5-coder:3b".to_string()
};
let provider = cx.new(|cx| OllamaCompletionProvider::new(model, api_key, cx));
editor.set_edit_prediction_provider(Some(provider), window, cx);
}
}
}
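The model selection in this last hunk reduces to a three-step precedence: model named in settings, else first discovered model, else a hard-coded default. A standalone sketch of just that rule (hypothetical helper; only the default name "qwen2.5-coder:3b" is taken from the diff):

// Hypothetical helper mirroring the precedence in assign_edit_prediction_provider:
// configured model > first discovered model > hard-coded default.
fn pick_model(configured: Option<&str>, discovered: &[String]) -> String {
    configured
        .map(str::to_string)
        .or_else(|| discovered.first().cloned())
        .unwrap_or_else(|| "qwen2.5-coder:3b".to_string())
}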