// ZIm/crates/ollama/src/ollama_completion_provider.rs

use crate::{GenerateOptions, GenerateRequest, Model, generate};
use anyhow::{Context as AnyhowContext, Result};
use futures::StreamExt;
use gpui::{App, AppContext, Context, Entity, EntityId, Global, Subscription, Task};
use http_client::HttpClient;
use inline_completion::{Direction, EditPredictionProvider, InlineCompletion};
use language::{Anchor, Buffer, ToOffset};
use project::Project;
use settings::SettingsStore;
use std::{path::Path, sync::Arc, time::Duration};
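
/// How long `refresh` waits before sending a request when debouncing is enabled,
/// so rapid keystrokes don't each fire a request to the Ollama server.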
pub const OLLAMA_DEBOUNCE_TIMEOUT: Duration = Duration::from_millis(75);
// Structure for passing settings model data without circular dependencies
#[derive(Clone, Debug)]
pub struct SettingsModel {
pub name: String,
pub display_name: Option<String>,
pub max_tokens: u64,
pub supports_tools: Option<bool>,
pub supports_images: Option<bool>,
pub supports_thinking: Option<bool>,
}
impl SettingsModel {
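    /// Converts this settings entry into a [`Model`], carrying over the
    /// user-declared display name, token limit, and capability flags.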
pub fn to_model(&self) -> Model {
Model::new(
&self.name,
self.display_name.as_deref(),
Some(self.max_tokens),
self.supports_tools,
self.supports_images,
self.supports_thinking,
)
}
}
// Global Ollama service for managing models across all providers
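//
// Typical wiring is to construct the service once at startup and register it as a
// global so every provider shares the same model list. A sketch (the real call
// sites live elsewhere in Zed; `http_client` and `api_url` stand in for whatever
// client and URL the caller already has):
//
//     let service = OllamaService::new(http_client, api_url, cx);
//     OllamaService::set_global(service, cx);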
pub struct OllamaService {
http_client: Arc<dyn HttpClient>,
api_url: String,
available_models: Vec<Model>,
fetch_models_task: Option<Task<Result<()>>>,
_settings_subscription: Subscription,
}
impl OllamaService {
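    /// Creates the service, subscribes to settings changes so the model list is
    /// re-fetched whenever settings change, and kicks off an initial fetch.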
pub fn new(http_client: Arc<dyn HttpClient>, api_url: String, cx: &mut App) -> Entity<Self> {
cx.new(|cx| {
let subscription = cx.observe_global::<SettingsStore>({
move |this: &mut OllamaService, cx| {
this.restart_fetch_models_task(cx);
}
});
let mut service = Self {
http_client,
api_url,
available_models: Vec::new(),
fetch_models_task: None,
_settings_subscription: subscription,
};
service.restart_fetch_models_task(cx);
service
})
}
pub fn global(cx: &App) -> Option<Entity<Self>> {
cx.try_global::<GlobalOllamaService>()
.map(|service| service.0.clone())
}
pub fn set_global(service: Entity<Self>, cx: &mut App) {
cx.set_global(GlobalOllamaService(service));
}
pub fn available_models(&self) -> &[Model] {
&self.available_models
}
pub fn refresh_models(&mut self, cx: &mut Context<Self>) {
self.restart_fetch_models_task(cx);
}
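    /// Replaces the model list with the models declared in the user's settings,
    /// then restarts the fetch task so models reported by the Ollama API are
    /// merged in; settings entries take precedence over API models with the
    /// same name.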
pub fn set_settings_models(
&mut self,
settings_models: Vec<SettingsModel>,
cx: &mut Context<Self>,
) {
// Convert settings models to our Model type
self.available_models = settings_models
.into_iter()
.map(|settings_model| settings_model.to_model())
.collect();
self.restart_fetch_models_task(cx);
}
fn restart_fetch_models_task(&mut self, cx: &mut Context<Self>) {
self.fetch_models_task = Some(self.fetch_models(cx));
}
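    /// Fetches models from the Ollama API and merges them with the models already
    /// declared in settings: embedding models and names already defined in
    /// settings are skipped, capabilities are looked up via `/api/show` (at most
    /// five requests in flight), and the combined list is sorted by name.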
fn fetch_models(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
let http_client = Arc::clone(&self.http_client);
let api_url = self.api_url.clone();
cx.spawn(async move |this, cx| {
// Get the current settings models to merge with API models
let settings_models = this.update(cx, |this, _cx| {
// Get just the names of models from settings to avoid duplicates
this.available_models
.iter()
.map(|m| m.name.clone())
.collect::<std::collections::HashSet<_>>()
})?;
// Fetch models from API
let api_models = match crate::get_models(http_client.as_ref(), &api_url, None).await {
Ok(models) => models,
Err(_) => return Ok(()), // Silently fail if API is unavailable
};
let tasks = api_models
.into_iter()
// Filter out embedding models
.filter(|model| !model.name.contains("-embed"))
// Filter out models that are already defined in settings
.filter(|model| !settings_models.contains(&model.name))
.map(|model| {
let http_client = Arc::clone(&http_client);
let api_url = api_url.clone();
async move {
let name = model.name.as_str();
let capabilities =
crate::show_model(http_client.as_ref(), &api_url, name).await?;
let ollama_model = Model::new(
name,
None,
None,
Some(capabilities.supports_tools()),
Some(capabilities.supports_vision()),
Some(capabilities.supports_thinking()),
);
Ok(ollama_model)
}
});
// Rate-limit capability fetches for API-discovered models
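            // Note: collecting into `Result<Vec<_>>` short-circuits on the first error,
            // so a single failed capability lookup drops the whole batch of
            // API-discovered models and only the settings models remain.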
let api_discovered_models: Vec<_> = futures::stream::iter(tasks)
.buffer_unordered(5)
.collect::<Vec<Result<_>>>()
.await
.into_iter()
.collect::<Result<Vec<_>>>()
.unwrap_or_default();
this.update(cx, |this, cx| {
// Append API-discovered models to existing settings models
this.available_models.extend(api_discovered_models);
// Sort all models by name
this.available_models.sort_by(|a, b| a.name.cmp(&b.name));
cx.notify();
})?;
Ok(())
})
}
}
struct GlobalOllamaService(Entity<OllamaService>);
impl Global for GlobalOllamaService {}
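// Inline edit prediction provider backed by a local Ollama server.
//
// A sketch of how a provider is typically created and attached to an editor
// (the exact call sites differ in the real app; see the tests below):
//
//     let provider = cx.new(|cx| {
//         OllamaCompletionProvider::new("qwen2.5-coder:3b".to_string(), None, cx)
//     });
//     editor.set_edit_prediction_provider(Some(provider), window, cx);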
pub struct OllamaCompletionProvider {
model: String,
buffer_id: Option<EntityId>,
file_extension: Option<String>,
current_completion: Option<String>,
pending_refresh: Option<Task<Result<()>>>,
api_key: Option<String>,
_service_subscription: Option<Subscription>,
}
impl OllamaCompletionProvider {
pub fn new(model: String, api_key: Option<String>, cx: &mut Context<Self>) -> Self {
let subscription = if let Some(service) = OllamaService::global(cx) {
Some(cx.observe(&service, |_this, _service, cx| {
cx.notify();
}))
} else {
None
};
Self {
model,
buffer_id: None,
file_extension: None,
current_completion: None,
pending_refresh: None,
api_key,
_service_subscription: subscription,
}
}
pub fn available_models(&self, cx: &App) -> Vec<Model> {
if let Some(service) = OllamaService::global(cx) {
service.read(cx).available_models().to_vec()
} else {
Vec::new()
}
}
pub fn refresh_models(&self, cx: &mut App) {
if let Some(service) = OllamaService::global(cx) {
service.update(cx, |service, cx| {
service.refresh_models(cx);
});
}
}
/// Updates the model used by this provider
pub fn update_model(&mut self, model: String) {
self.model = model;
}
/// Updates the file extension used by this provider
pub fn update_file_extension(&mut self, new_file_extension: String) {
self.file_extension = Some(new_file_extension);
}
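    /// Returns the text immediately before and after the cursor (up to
    /// `context_size` bytes on each side) to use as the fill-in-the-middle
    /// prefix and suffix.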
fn extract_context(&self, buffer: &Buffer, cursor_position: Anchor) -> (String, String) {
let cursor_offset = cursor_position.to_offset(buffer);
let text = buffer.text();
// Get reasonable context around cursor
let context_size = 2000; // 2KB before and after cursor
        let mut start = cursor_offset.saturating_sub(context_size);
        let mut end = (cursor_offset + context_size).min(text.len());
        // Clamp both ends to char boundaries so the slices below can't panic when
        // the window lands in the middle of a multi-byte character.
        while !text.is_char_boundary(start) {
            start -= 1;
        }
        while !text.is_char_boundary(end) {
            end += 1;
        }
        let prefix = text[start..cursor_offset].to_string();
        let suffix = text[cursor_offset..end].to_string();
(prefix, suffix)
}
    /// Get stop tokens for the current model.
    /// For now this only handles CodeLlama code models (e.g. codellama:7b-code),
    /// which we found include their `<EOT>` template token in the completion text.
    /// We would prefer to let Ollama abstract all template tokens away, but it
    /// currently misses this case.
fn get_stop_tokens(&self) -> Option<Vec<String>> {
if self.model.contains("codellama") && self.model.contains("code") {
Some(vec!["<EOT>".to_string()])
} else {
None
}
}
}
impl EditPredictionProvider for OllamaCompletionProvider {
fn name() -> &'static str {
"ollama"
}
fn display_name() -> &'static str {
"Ollama"
}
fn show_completions_in_menu() -> bool {
true
}
fn is_enabled(&self, _buffer: &Entity<Buffer>, _cursor_position: Anchor, _cx: &App) -> bool {
true
}
fn is_refreshing(&self) -> bool {
self.pending_refresh.is_some()
}
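    /// Requests a completion from Ollama for the text around `cursor_position`,
    /// optionally debounced. The HTTP client and API URL come from the global
    /// [`OllamaService`] when it exists, falling back to the project's client and
    /// the default API URL. The response is cached in `current_completion` and
    /// surfaced through `suggest`.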
fn refresh(
&mut self,
project: Option<Entity<Project>>,
buffer: Entity<Buffer>,
cursor_position: Anchor,
debounce: bool,
cx: &mut Context<Self>,
) {
// Get API settings from the global Ollama service or fallback
let (http_client, api_url) = if let Some(service) = OllamaService::global(cx) {
let service_ref = service.read(cx);
(service_ref.http_client.clone(), service_ref.api_url.clone())
} else {
// Fallback if global service isn't available
(
project
.as_ref()
.map(|p| p.read(cx).client().http_client() as Arc<dyn HttpClient>)
.unwrap_or_else(|| {
Arc::new(http_client::BlockedHttpClient::new()) as Arc<dyn HttpClient>
}),
crate::OLLAMA_API_URL.to_string(),
)
};
self.pending_refresh = Some(cx.spawn(async move |this, cx| {
if debounce {
cx.background_executor()
.timer(OLLAMA_DEBOUNCE_TIMEOUT)
.await;
}
let (prefix, suffix) = this.update(cx, |this, cx| {
let buffer_snapshot = buffer.read(cx);
this.buffer_id = Some(buffer.entity_id());
this.file_extension = buffer_snapshot.file().and_then(|file| {
Some(
Path::new(file.file_name(cx))
.extension()?
.to_str()?
.to_string(),
)
});
this.extract_context(buffer_snapshot, cursor_position)
})?;
let (model, api_key) =
this.update(cx, |this, _| (this.model.clone(), this.api_key.clone()))?;
let stop_tokens = this.update(cx, |this, _| this.get_stop_tokens())?;
let request = GenerateRequest {
model,
prompt: prefix,
suffix: Some(suffix),
stream: false,
options: Some(GenerateOptions {
num_predict: Some(150), // Reasonable completion length
temperature: Some(0.1), // Low temperature for more deterministic results
top_p: Some(0.95),
stop: stop_tokens,
}),
keep_alive: None,
context: None,
};
let response = generate(http_client.as_ref(), &api_url, api_key, request)
.await
.context("Failed to get completion from Ollama");
this.update(cx, |this, cx| {
this.pending_refresh = None;
match response {
Ok(response) if !response.response.trim().is_empty() => {
this.current_completion = Some(response.response);
}
_ => {
this.current_completion = None;
}
}
cx.notify();
})?;
Ok(())
}));
}
fn cycle(
&mut self,
_buffer: Entity<Buffer>,
_cursor_position: Anchor,
_direction: Direction,
_cx: &mut Context<Self>,
) {
// Ollama doesn't provide multiple completions in a single request
// Could be implemented by making multiple requests with different temperatures
// or by using different models
}
fn accept(&mut self, _cx: &mut Context<Self>) {
self.current_completion = None;
// TODO: Could send accept telemetry to Ollama if supported
}
fn discard(&mut self, _cx: &mut Context<Self>) {
self.current_completion = None;
// TODO: Could send discard telemetry to Ollama if supported
}
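    /// Returns the portion of the cached completion that the user has not already
    /// typed, as a single insertion at the cursor, or `None` if the completion
    /// belongs to a different buffer or nothing remains to suggest.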
fn suggest(
&mut self,
buffer: &Entity<Buffer>,
cursor_position: Anchor,
cx: &mut Context<Self>,
) -> Option<InlineCompletion> {
let buffer_id = buffer.entity_id();
if Some(buffer_id) != self.buffer_id {
return None;
}
let completion_text = self.current_completion.as_ref()?.clone();
if completion_text.trim().is_empty() {
return None;
}
let buffer_snapshot = buffer.read(cx);
let cursor_offset = cursor_position.to_offset(buffer_snapshot);
// Get text before cursor to check what's already been typed
let text_before_cursor = buffer_snapshot
.text_for_range(0..cursor_offset)
.collect::<String>();
// Find how much of the completion has already been typed by checking
// if the text before the cursor ends with a prefix of our completion
let mut prefix_len = 0;
        for i in 1..=completion_text.len().min(text_before_cursor.len()) {
            // Only slice at char boundaries to avoid panicking on multi-byte characters.
            if completion_text.is_char_boundary(i)
                && text_before_cursor.ends_with(&completion_text[..i])
            {
                prefix_len = i;
            }
        }
// Only suggest the remaining part of the completion
let remaining_completion = &completion_text[prefix_len..];
if remaining_completion.trim().is_empty() {
return None;
}
let position = cursor_position.bias_right(buffer_snapshot);
Some(InlineCompletion {
id: None,
edits: vec![(position..position, remaining_completion.to_string())],
edit_preview: None,
})
}
}
#[cfg(test)]
mod tests {
use super::*;
use gpui::{AppContext, TestAppContext};
use client;
use language::Buffer;
use project::Project;
use settings::SettingsStore;
fn init_test(cx: &mut TestAppContext) {
cx.update(|cx| {
let settings_store = SettingsStore::test(cx);
cx.set_global(settings_store);
theme::init(theme::LoadThemes::JustBase, cx);
client::init_settings(cx);
language::init(cx);
editor::init_settings(cx);
Project::init_settings(cx);
workspace::init_settings(cx);
});
}
    /// Test that stop tokens are only returned for CodeLlama code models
#[gpui::test]
fn test_get_stop_tokens(cx: &mut TestAppContext) {
init_test(cx);
// Test CodeLlama code model gets stop tokens
let codellama_provider = cx.update(|cx| {
cx.new(|cx| OllamaCompletionProvider::new("codellama:7b-code".to_string(), None, cx))
});
codellama_provider.read_with(cx, |provider, _| {
assert_eq!(provider.get_stop_tokens(), Some(vec!["<EOT>".to_string()]));
});
// Test non-CodeLlama model doesn't get stop tokens
let qwen_provider = cx.update(|cx| {
cx.new(|cx| OllamaCompletionProvider::new("qwen2.5-coder:3b".to_string(), None, cx))
});
qwen_provider.read_with(cx, |provider, _| {
assert_eq!(provider.get_stop_tokens(), None);
});
}
#[gpui::test]
async fn test_model_discovery(cx: &mut TestAppContext) {
init_test(cx);
// Create fake HTTP client
let fake_http_client = Arc::new(crate::fake::FakeHttpClient::new());
// Mock /api/tags response (list models)
let models_response = serde_json::json!({
"models": [
{
"name": "qwen2.5-coder:3b",
"modified_at": "2024-01-01T00:00:00Z",
"size": 1000000,
"digest": "abc123",
"details": {
"format": "gguf",
"family": "qwen2",
"families": ["qwen2"],
"parameter_size": "3B",
"quantization_level": "Q4_0"
}
},
{
"name": "codellama:7b-code",
"modified_at": "2024-01-01T00:00:00Z",
"size": 2000000,
"digest": "def456",
"details": {
"format": "gguf",
"family": "codellama",
"families": ["codellama"],
"parameter_size": "7B",
"quantization_level": "Q4_0"
}
},
{
"name": "nomic-embed-text",
"modified_at": "2024-01-01T00:00:00Z",
"size": 500000,
"digest": "ghi789",
"details": {
"format": "gguf",
"family": "nomic-embed",
"families": ["nomic-embed"],
"parameter_size": "137M",
"quantization_level": "Q4_0"
}
}
]
});
fake_http_client.set_response("/api/tags", models_response.to_string());
// Mock /api/show responses for model capabilities
let qwen_capabilities = serde_json::json!({
"capabilities": ["tools", "thinking"]
});
let _codellama_capabilities = serde_json::json!({
"capabilities": []
});
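        // Only one /api/show response is registered; the fake client serves it for
        // every model lookup, so the codellama payload above goes unused.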
fake_http_client.set_response("/api/show", qwen_capabilities.to_string());
// Create global Ollama service for testing
let service = cx.update(|cx| {
OllamaService::new(
fake_http_client.clone(),
"http://localhost:11434".to_string(),
cx,
)
});
// Set it as global
cx.update(|cx| {
OllamaService::set_global(service.clone(), cx);
});
// Create completion provider
let provider = cx.update(|cx| {
cx.new(|cx| OllamaCompletionProvider::new("qwen2.5-coder:3b".to_string(), None, cx))
});
// Wait for model discovery to complete
cx.background_executor.run_until_parked();
// Verify models were discovered through the global provider
provider.read_with(cx, |provider, cx| {
let models = provider.available_models(cx);
assert_eq!(models.len(), 2); // Should exclude nomic-embed-text
let model_names: Vec<&str> = models.iter().map(|m| m.name.as_str()).collect();
assert!(model_names.contains(&"codellama:7b-code"));
assert!(model_names.contains(&"qwen2.5-coder:3b"));
assert!(!model_names.contains(&"nomic-embed-text"));
});
}
#[gpui::test]
async fn test_model_discovery_api_failure(cx: &mut TestAppContext) {
init_test(cx);
// Create fake HTTP client that returns errors
let fake_http_client = Arc::new(crate::fake::FakeHttpClient::new());
fake_http_client.set_error("Connection refused");
// Create global Ollama service that will fail
let service = cx.update(|cx| {
OllamaService::new(
fake_http_client.clone(),
"http://localhost:11434".to_string(),
cx,
)
});
cx.update(|cx| {
OllamaService::set_global(service.clone(), cx);
});
// Create completion provider
let provider = cx.update(|cx| {
cx.new(|cx| OllamaCompletionProvider::new("qwen2.5-coder:3b".to_string(), None, cx))
});
// Wait for model discovery to complete (with failure)
cx.background_executor.run_until_parked();
// Verify graceful handling - should have empty model list
provider.read_with(cx, |provider, cx| {
let models = provider.available_models(cx);
assert_eq!(models.len(), 0);
});
}
#[gpui::test]
async fn test_refresh_models(cx: &mut TestAppContext) {
init_test(cx);
let fake_http_client = Arc::new(crate::fake::FakeHttpClient::new());
// Initially return empty model list
let empty_response = serde_json::json!({"models": []});
fake_http_client.set_response("/api/tags", empty_response.to_string());
// Create global Ollama service
let service = cx.update(|cx| {
OllamaService::new(
fake_http_client.clone(),
"http://localhost:11434".to_string(),
cx,
)
});
cx.update(|cx| {
OllamaService::set_global(service.clone(), cx);
});
let provider = cx.update(|cx| {
cx.new(|cx| OllamaCompletionProvider::new("qwen2.5-coder:7b".to_string(), None, cx))
});
cx.background_executor.run_until_parked();
// Verify initially empty
provider.read_with(cx, |provider, cx| {
assert_eq!(provider.available_models(cx).len(), 0);
});
// Update mock to return models
let models_response = serde_json::json!({
"models": [
{
"name": "qwen2.5-coder:7b",
"modified_at": "2024-01-01T00:00:00Z",
"size": 1000000,
"digest": "abc123",
"details": {
"format": "gguf",
"family": "qwen2",
"families": ["qwen2"],
"parameter_size": "7B",
"quantization_level": "Q4_0"
}
}
]
});
fake_http_client.set_response("/api/tags", models_response.to_string());
let capabilities = serde_json::json!({
"capabilities": ["tools", "thinking"]
});
fake_http_client.set_response("/api/show", capabilities.to_string());
// Trigger refresh
provider.update(cx, |provider, cx| {
provider.refresh_models(cx);
});
cx.background_executor.run_until_parked();
// Verify models were refreshed
provider.read_with(cx, |provider, cx| {
let models = provider.available_models(cx);
assert_eq!(models.len(), 1);
assert_eq!(models[0].name, "qwen2.5-coder:7b");
});
}
#[gpui::test]
async fn test_full_completion_flow(cx: &mut TestAppContext) {
init_test(cx);
// Create a buffer with realistic code content
let buffer = cx.update(|cx| cx.new(|cx| Buffer::local("fn test() {\n \n}", cx)));
let cursor_position = buffer.read_with(cx, |buffer, _| {
buffer.anchor_before(11) // Position in the middle of the function
});
// Create fake HTTP client and set up global service
let fake_http_client = Arc::new(crate::fake::FakeHttpClient::new());
fake_http_client.set_generate_response("println!(\"Hello\");");
// Create global Ollama service
let service = cx.update(|cx| {
OllamaService::new(
fake_http_client.clone(),
"http://localhost:11434".to_string(),
cx,
)
});
cx.update(|cx| {
OllamaService::set_global(service.clone(), cx);
});
// Create provider
let provider = cx.update(|cx| {
cx.new(|cx| OllamaCompletionProvider::new("qwen2.5-coder:3b".to_string(), None, cx))
});
// Trigger completion refresh (no debounce for test speed)
provider.update(cx, |provider, cx| {
provider.refresh(None, buffer.clone(), cursor_position, false, cx);
});
// Wait for completion task to complete
cx.background_executor.run_until_parked();
// Verify completion was processed and stored
provider.read_with(cx, |provider, _cx| {
assert!(provider.current_completion.is_some());
assert_eq!(
provider.current_completion.as_ref().unwrap(),
"println!(\"Hello\");"
);
assert!(!provider.is_refreshing());
});
// Test suggestion logic returns the completion
let suggestion = cx.update(|cx| {
provider.update(cx, |provider, cx| {
provider.suggest(&buffer, cursor_position, cx)
})
});
assert!(suggestion.is_some());
let suggestion = suggestion.unwrap();
assert_eq!(suggestion.edits.len(), 1);
assert_eq!(suggestion.edits[0].1, "println!(\"Hello\");");
// Verify acceptance clears the completion
provider.update(cx, |provider, cx| {
provider.accept(cx);
});
provider.read_with(cx, |provider, _cx| {
assert!(provider.current_completion.is_none());
});
}
/// Test that partial typing is handled correctly - only suggests untyped portion
#[gpui::test]
async fn test_partial_typing_handling(cx: &mut TestAppContext) {
init_test(cx);
// Create buffer where user has partially typed "vec"
let buffer = cx.update(|cx| cx.new(|cx| Buffer::local("let result = vec", cx)));
let cursor_position = buffer.read_with(cx, |buffer, _| {
buffer.anchor_after(16) // After "vec"
});
// Create fake HTTP client and set up global service
let fake_http_client = Arc::new(crate::fake::FakeHttpClient::new());
// Create global Ollama service
let service = cx.update(|cx| {
OllamaService::new(
fake_http_client.clone(),
"http://localhost:11434".to_string(),
cx,
)
});
cx.update(|cx| {
OllamaService::set_global(service.clone(), cx);
});
// Create provider
let provider = cx.update(|cx| {
cx.new(|cx| OllamaCompletionProvider::new("qwen2.5-coder:3b".to_string(), None, cx))
});
// Configure response that starts with what user already typed
fake_http_client.set_generate_response("vec![1, 2, 3]");
provider.update(cx, |provider, cx| {
provider.refresh(None, buffer.clone(), cursor_position, false, cx);
});
cx.background_executor.run_until_parked();
// Should suggest only the remaining part after "vec"
let suggestion = cx.update(|cx| {
provider.update(cx, |provider, cx| {
provider.suggest(&buffer, cursor_position, cx)
})
});
// Verify we get a reasonable suggestion
if let Some(suggestion) = suggestion {
assert_eq!(suggestion.edits.len(), 1);
assert!(suggestion.edits[0].1.contains("1, 2, 3"));
}
}
#[gpui::test]
async fn test_accept_partial_ollama_suggestion(cx: &mut TestAppContext) {
init_test(cx);
let mut editor_cx = editor::test::editor_test_context::EditorTestContext::new(cx).await;
// Create fake HTTP client and set up global service
let fake_http_client = Arc::new(crate::fake::FakeHttpClient::new());
fake_http_client.set_generate_response("vec![hello, world]");
// Create global Ollama service
let service = cx.update(|cx| {
OllamaService::new(
fake_http_client.clone(),
"http://localhost:11434".to_string(),
cx,
)
});
cx.update(|cx| {
OllamaService::set_global(service.clone(), cx);
});
// Create provider
let provider = cx.update(|cx| {
cx.new(|cx| OllamaCompletionProvider::new("qwen2.5-coder:3b".to_string(), None, cx))
});
// Set up the editor with the Ollama provider
editor_cx.update_editor(|editor, window, cx| {
editor.set_edit_prediction_provider(Some(provider.clone()), window, cx);
});
// Set initial state
editor_cx.set_state("let items = ˇ");
// Trigger the completion through the provider
let buffer =
editor_cx.multibuffer(|multibuffer, _| multibuffer.as_singleton().unwrap().clone());
let cursor_position = editor_cx.buffer_snapshot().anchor_after(12);
provider.update(cx, |provider, cx| {
provider.refresh(None, buffer, cursor_position, false, cx);
});
cx.background_executor.run_until_parked();
editor_cx.update_editor(|editor, window, cx| {
editor.refresh_inline_completion(false, true, window, cx);
});
cx.background_executor.run_until_parked();
editor_cx.update_editor(|editor, window, cx| {
// Verify we have an active completion
assert!(editor.has_active_inline_completion());
// The display text should show the full completion
assert_eq!(editor.display_text(cx), "let items = vec![hello, world]");
// But the actual text should only show what's been typed
assert_eq!(editor.text(cx), "let items = ");
// Accept first partial - should accept "vec" (alphabetic characters)
editor.accept_partial_inline_completion(&Default::default(), window, cx);
// Assert the buffer now contains the first partially accepted text
assert_eq!(editor.text(cx), "let items = vec");
// Completion should still be active for remaining text
assert!(editor.has_active_inline_completion());
// Accept second partial - should accept "![" (non-alphabetic characters)
editor.accept_partial_inline_completion(&Default::default(), window, cx);
// Assert the buffer now contains both partial acceptances
assert_eq!(editor.text(cx), "let items = vec![");
// Completion should still be active for remaining text
assert!(editor.has_active_inline_completion());
});
}
#[gpui::test]
async fn test_completion_invalidation(cx: &mut TestAppContext) {
init_test(cx);
let mut editor_cx = editor::test::editor_test_context::EditorTestContext::new(cx).await;
// Create fake HTTP client and set up global service
let fake_http_client = Arc::new(crate::fake::FakeHttpClient::new());
fake_http_client.set_generate_response("bar");
// Create global Ollama service
let service = cx.update(|cx| {
OllamaService::new(
fake_http_client.clone(),
"http://localhost:11434".to_string(),
cx,
)
});
cx.update(|cx| {
OllamaService::set_global(service.clone(), cx);
});
// Create provider
let provider = cx.update(|cx| {
cx.new(|cx| OllamaCompletionProvider::new("qwen2.5-coder:3b".to_string(), None, cx))
});
// Set up the editor with the Ollama provider
editor_cx.update_editor(|editor, window, cx| {
editor.set_edit_prediction_provider(Some(provider.clone()), window, cx);
});
editor_cx.set_state("fooˇ");
// Trigger the completion through the provider
let buffer =
editor_cx.multibuffer(|multibuffer, _| multibuffer.as_singleton().unwrap().clone());
let cursor_position = editor_cx.buffer_snapshot().anchor_after(3); // After "foo"
provider.update(cx, |provider, cx| {
provider.refresh(None, buffer, cursor_position, false, cx);
});
cx.background_executor.run_until_parked();
editor_cx.update_editor(|editor, window, cx| {
editor.refresh_inline_completion(false, true, window, cx);
});
cx.background_executor.run_until_parked();
editor_cx.update_editor(|editor, window, cx| {
assert!(editor.has_active_inline_completion());
assert_eq!(editor.display_text(cx), "foobar");
assert_eq!(editor.text(cx), "foo");
// Backspace within the original text - completion should remain
editor.backspace(&Default::default(), window, cx);
assert!(editor.has_active_inline_completion());
assert_eq!(editor.display_text(cx), "fobar");
assert_eq!(editor.text(cx), "fo");
editor.backspace(&Default::default(), window, cx);
assert!(editor.has_active_inline_completion());
assert_eq!(editor.display_text(cx), "fbar");
assert_eq!(editor.text(cx), "f");
// This backspace removes all original text - should invalidate completion
editor.backspace(&Default::default(), window, cx);
assert!(!editor.has_active_inline_completion());
assert_eq!(editor.display_text(cx), "");
assert_eq!(editor.text(cx), "");
});
}
#[gpui::test]
async fn test_settings_model_merging(cx: &mut TestAppContext) {
init_test(cx);
// Create fake HTTP client that returns some API models
let fake_http_client = Arc::new(crate::fake::FakeHttpClient::new());
// Mock /api/tags response (list models)
let models_response = serde_json::json!({
"models": [
{
"name": "api-model-1",
"modified_at": "2024-01-01T00:00:00Z",
"size": 1000000,
"digest": "abc123",
"details": {
"format": "gguf",
"family": "llama",
"families": ["llama"],
"parameter_size": "7B",
"quantization_level": "Q4_0"
}
},
{
"name": "shared-model",
"modified_at": "2024-01-01T00:00:00Z",
"size": 2000000,
"digest": "def456",
"details": {
"format": "gguf",
"family": "llama",
"families": ["llama"],
"parameter_size": "13B",
"quantization_level": "Q4_0"
}
}
]
});
fake_http_client.set_response("/api/tags", models_response.to_string());
// Mock /api/show responses for each model
let show_response = serde_json::json!({
"capabilities": ["tools", "vision"]
});
fake_http_client.set_response("/api/show", show_response.to_string());
// Create service
let service = cx.update(|cx| {
OllamaService::new(
fake_http_client.clone(),
"http://localhost:11434".to_string(),
cx,
)
});
// Add settings models (including one that overlaps with API)
let settings_models = vec![
SettingsModel {
name: "custom-model-1".to_string(),
display_name: Some("Custom Model 1".to_string()),
max_tokens: 4096,
supports_tools: Some(true),
supports_images: Some(false),
supports_thinking: Some(false),
},
SettingsModel {
name: "shared-model".to_string(), // This should take precedence over API
display_name: Some("Custom Shared Model".to_string()),
max_tokens: 8192,
supports_tools: Some(true),
supports_images: Some(true),
supports_thinking: Some(true),
},
];
cx.update(|cx| {
service.update(cx, |service, cx| {
service.set_settings_models(settings_models, cx);
});
});
// Wait for models to be fetched and merged
cx.run_until_parked();
// Verify merged models
let models = cx.update(|cx| service.read(cx).available_models().to_vec());
assert_eq!(models.len(), 3); // 2 settings models + 1 unique API model
// Models should be sorted alphabetically, so check by name
let model_names: Vec<&str> = models.iter().map(|m| m.name.as_str()).collect();
assert_eq!(
model_names,
vec!["api-model-1", "custom-model-1", "shared-model"]
);
// Check custom model from settings
let custom_model = models.iter().find(|m| m.name == "custom-model-1").unwrap();
assert_eq!(
custom_model.display_name,
Some("Custom Model 1".to_string())
);
assert_eq!(custom_model.max_tokens, 4096);
// Settings model should override API model for shared-model
let shared_model = models.iter().find(|m| m.name == "shared-model").unwrap();
assert_eq!(
shared_model.display_name,
Some("Custom Shared Model".to_string())
);
assert_eq!(shared_model.max_tokens, 8192);
assert_eq!(shared_model.supports_tools, Some(true));
assert_eq!(shared_model.supports_vision, Some(true));
assert_eq!(shared_model.supports_thinking, Some(true));
// API-only model should be included
let api_model = models.iter().find(|m| m.name == "api-model-1").unwrap();
assert!(api_model.display_name.is_none()); // API models don't have custom display names
}
}