
This allows for the scenario where the user doesn't have access to Ollama's model listing and needs to tell Zed explicitly, by hand, which models are available.
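For illustration, a minimal sketch of that flow, assuming a `service: Entity<OllamaService>` and a gpui context `cx` are in scope (the model name and token limit below are invented): a hand-declared model travels as a `SettingsModel`, is converted into a regular `Model`, and is listed ahead of anything discovered from the API. The `SettingsModel` struct exists so settings data can be passed in without a circular dependency, as noted on the struct definition below.

// Sketch only: the model name and token limit are made up for illustration.
let manual = SettingsModel {
    name: "my-local-finetune:latest".to_string(),
    display_name: Some("My Local Finetune".to_string()),
    max_tokens: 8192,
    supports_tools: Some(false),
    supports_images: Some(false),
    supports_thinking: Some(false),
};
// `to_model()` carries the user-declared capabilities into a regular `Model`;
// `set_settings_models` then lists it ahead of anything the API reports.
service.update(cx, |service, cx| {
    service.set_settings_models(vec![manual], cx);
});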
use crate::{GenerateOptions, GenerateRequest, Model, generate};
use anyhow::{Context as AnyhowContext, Result};
use futures::StreamExt;
use std::{path::Path, sync::Arc, time::Duration};

use gpui::{App, AppContext, Context, Entity, EntityId, Global, Subscription, Task};
use http_client::HttpClient;
use inline_completion::{Direction, EditPredictionProvider, InlineCompletion};
use language::{Anchor, Buffer, ToOffset};
use settings::SettingsStore;

use project::Project;

pub const OLLAMA_DEBOUNCE_TIMEOUT: Duration = Duration::from_millis(75);

// Structure for passing settings model data without circular dependencies
#[derive(Clone, Debug)]
pub struct SettingsModel {
    pub name: String,
    pub display_name: Option<String>,
    pub max_tokens: u64,
    pub supports_tools: Option<bool>,
    pub supports_images: Option<bool>,
    pub supports_thinking: Option<bool>,
}

impl SettingsModel {
    pub fn to_model(&self) -> Model {
        Model::new(
            &self.name,
            self.display_name.as_deref(),
            Some(self.max_tokens),
            self.supports_tools,
            self.supports_images,
            self.supports_thinking,
        )
    }
}

// Global Ollama service for managing models across all providers
pub struct OllamaService {
    http_client: Arc<dyn HttpClient>,
    api_url: String,
    available_models: Vec<Model>,
    fetch_models_task: Option<Task<Result<()>>>,
    _settings_subscription: Subscription,
}

impl OllamaService {
    pub fn new(http_client: Arc<dyn HttpClient>, api_url: String, cx: &mut App) -> Entity<Self> {
        cx.new(|cx| {
            let subscription = cx.observe_global::<SettingsStore>({
                move |this: &mut OllamaService, cx| {
                    this.restart_fetch_models_task(cx);
                }
            });

            let mut service = Self {
                http_client,
                api_url,
                available_models: Vec::new(),
                fetch_models_task: None,
                _settings_subscription: subscription,
            };

            service.restart_fetch_models_task(cx);
            service
        })
    }

    pub fn global(cx: &App) -> Option<Entity<Self>> {
        cx.try_global::<GlobalOllamaService>()
            .map(|service| service.0.clone())
    }

    pub fn set_global(service: Entity<Self>, cx: &mut App) {
        cx.set_global(GlobalOllamaService(service));
    }

    pub fn available_models(&self) -> &[Model] {
        &self.available_models
    }

    pub fn refresh_models(&mut self, cx: &mut Context<Self>) {
        self.restart_fetch_models_task(cx);
    }

    pub fn set_settings_models(
        &mut self,
        settings_models: Vec<SettingsModel>,
        cx: &mut Context<Self>,
    ) {
        // Convert settings models to our Model type
        self.available_models = settings_models
            .into_iter()
            .map(|settings_model| settings_model.to_model())
            .collect();
        self.restart_fetch_models_task(cx);
    }

    fn restart_fetch_models_task(&mut self, cx: &mut Context<Self>) {
        self.fetch_models_task = Some(self.fetch_models(cx));
    }

    fn fetch_models(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
        let http_client = Arc::clone(&self.http_client);
        let api_url = self.api_url.clone();

        cx.spawn(async move |this, cx| {
            // Get the current settings models to merge with API models
            let settings_models = this.update(cx, |this, _cx| {
                // Get just the names of models from settings to avoid duplicates
                this.available_models
                    .iter()
                    .map(|m| m.name.clone())
                    .collect::<std::collections::HashSet<_>>()
            })?;

            // Fetch models from API
            let api_models = match crate::get_models(http_client.as_ref(), &api_url, None).await {
                Ok(models) => models,
                Err(_) => return Ok(()), // Silently fail if API is unavailable
            };

            let tasks = api_models
                .into_iter()
                // Filter out embedding models
                .filter(|model| !model.name.contains("-embed"))
                // Filter out models that are already defined in settings
                .filter(|model| !settings_models.contains(&model.name))
                .map(|model| {
                    let http_client = Arc::clone(&http_client);
                    let api_url = api_url.clone();
                    async move {
                        let name = model.name.as_str();
                        let capabilities =
                            crate::show_model(http_client.as_ref(), &api_url, name).await?;
                        let ollama_model = Model::new(
                            name,
                            None,
                            None,
                            Some(capabilities.supports_tools()),
                            Some(capabilities.supports_vision()),
                            Some(capabilities.supports_thinking()),
                        );
                        Ok(ollama_model)
                    }
                });

            // Rate-limit capability fetches for API-discovered models
            let api_discovered_models: Vec<_> = futures::stream::iter(tasks)
                .buffer_unordered(5)
                .collect::<Vec<Result<_>>>()
                .await
                .into_iter()
                .collect::<Result<Vec<_>>>()
                .unwrap_or_default();

            this.update(cx, |this, cx| {
                // Append API-discovered models to existing settings models
                this.available_models.extend(api_discovered_models);
                // Sort all models by name
                this.available_models.sort_by(|a, b| a.name.cmp(&b.name));
                cx.notify();
            })?;

            Ok(())
        })
    }
}

struct GlobalOllamaService(Entity<OllamaService>);

impl Global for GlobalOllamaService {}

pub struct OllamaCompletionProvider {
    model: String,
    buffer_id: Option<EntityId>,
    file_extension: Option<String>,
    current_completion: Option<String>,
    pending_refresh: Option<Task<Result<()>>>,
    api_key: Option<String>,
    _service_subscription: Option<Subscription>,
}

impl OllamaCompletionProvider {
    pub fn new(model: String, api_key: Option<String>, cx: &mut Context<Self>) -> Self {
        let subscription = if let Some(service) = OllamaService::global(cx) {
            Some(cx.observe(&service, |_this, _service, cx| {
                cx.notify();
            }))
        } else {
            None
        };

        Self {
            model,
            buffer_id: None,
            file_extension: None,
            current_completion: None,
            pending_refresh: None,
            api_key,
            _service_subscription: subscription,
        }
    }

    pub fn available_models(&self, cx: &App) -> Vec<Model> {
        if let Some(service) = OllamaService::global(cx) {
            service.read(cx).available_models().to_vec()
        } else {
            Vec::new()
        }
    }

    pub fn refresh_models(&self, cx: &mut App) {
        if let Some(service) = OllamaService::global(cx) {
            service.update(cx, |service, cx| {
                service.refresh_models(cx);
            });
        }
    }

    /// Updates the model used by this provider
    pub fn update_model(&mut self, model: String) {
        self.model = model;
    }

    /// Updates the file extension used by this provider
    pub fn update_file_extension(&mut self, new_file_extension: String) {
        self.file_extension = Some(new_file_extension);
    }

    fn extract_context(&self, buffer: &Buffer, cursor_position: Anchor) -> (String, String) {
        let cursor_offset = cursor_position.to_offset(buffer);
        let text = buffer.text();

        // Get reasonable context around cursor
        let context_size = 2000; // 2KB before and after cursor

        let start = cursor_offset.saturating_sub(context_size);
        let end = (cursor_offset + context_size).min(text.len());

        let prefix = text[start..cursor_offset].to_string();
        let suffix = text[cursor_offset..end].to_string();

        (prefix, suffix)
    }

    /// Get stop tokens for the current model
    /// For now we only handle the case for codellama:7b-code model
    /// that we found was including the stop token in the completion suggestion.
    /// We wanted to avoid going down this route and let Ollama abstract all template tokens away.
    /// But apparently, and surprisingly for a llama model, Ollama misses this case.
    fn get_stop_tokens(&self) -> Option<Vec<String>> {
        if self.model.contains("codellama") && self.model.contains("code") {
            Some(vec!["<EOT>".to_string()])
        } else {
            None
        }
    }
}

impl EditPredictionProvider for OllamaCompletionProvider {
    fn name() -> &'static str {
        "ollama"
    }

    fn display_name() -> &'static str {
        "Ollama"
    }

    fn show_completions_in_menu() -> bool {
        true
    }

    fn is_enabled(&self, _buffer: &Entity<Buffer>, _cursor_position: Anchor, _cx: &App) -> bool {
        true
    }

    fn is_refreshing(&self) -> bool {
        self.pending_refresh.is_some()
    }

    fn refresh(
        &mut self,
        project: Option<Entity<Project>>,
        buffer: Entity<Buffer>,
        cursor_position: Anchor,
        debounce: bool,
        cx: &mut Context<Self>,
    ) {
        // Get API settings from the global Ollama service or fallback
        let (http_client, api_url) = if let Some(service) = OllamaService::global(cx) {
            let service_ref = service.read(cx);
            (service_ref.http_client.clone(), service_ref.api_url.clone())
        } else {
            // Fallback if global service isn't available
            (
                project
                    .as_ref()
                    .map(|p| p.read(cx).client().http_client() as Arc<dyn HttpClient>)
                    .unwrap_or_else(|| {
                        Arc::new(http_client::BlockedHttpClient::new()) as Arc<dyn HttpClient>
                    }),
                crate::OLLAMA_API_URL.to_string(),
            )
        };

        self.pending_refresh = Some(cx.spawn(async move |this, cx| {
            if debounce {
                cx.background_executor()
                    .timer(OLLAMA_DEBOUNCE_TIMEOUT)
                    .await;
            }

            let (prefix, suffix) = this.update(cx, |this, cx| {
                let buffer_snapshot = buffer.read(cx);
                this.buffer_id = Some(buffer.entity_id());
                this.file_extension = buffer_snapshot.file().and_then(|file| {
                    Some(
                        Path::new(file.file_name(cx))
                            .extension()?
                            .to_str()?
                            .to_string(),
                    )
                });
                this.extract_context(buffer_snapshot, cursor_position)
            })?;

            let (model, api_key) =
                this.update(cx, |this, _| (this.model.clone(), this.api_key.clone()))?;

            let stop_tokens = this.update(cx, |this, _| this.get_stop_tokens())?;

            let request = GenerateRequest {
                model,
                prompt: prefix,
                suffix: Some(suffix),
                stream: false,
                options: Some(GenerateOptions {
                    num_predict: Some(150), // Reasonable completion length
                    temperature: Some(0.1), // Low temperature for more deterministic results
                    top_p: Some(0.95),
                    stop: stop_tokens,
                }),
                keep_alive: None,
                context: None,
            };

            let response = generate(http_client.as_ref(), &api_url, api_key, request)
                .await
                .context("Failed to get completion from Ollama");

            this.update(cx, |this, cx| {
                this.pending_refresh = None;
                match response {
                    Ok(response) if !response.response.trim().is_empty() => {
                        this.current_completion = Some(response.response);
                    }
                    _ => {
                        this.current_completion = None;
                    }
                }
                cx.notify();
            })?;

            Ok(())
        }));
    }

    fn cycle(
        &mut self,
        _buffer: Entity<Buffer>,
        _cursor_position: Anchor,
        _direction: Direction,
        _cx: &mut Context<Self>,
    ) {
        // Ollama doesn't provide multiple completions in a single request
        // Could be implemented by making multiple requests with different temperatures
        // or by using different models
    }

    fn accept(&mut self, _cx: &mut Context<Self>) {
        self.current_completion = None;
        // TODO: Could send accept telemetry to Ollama if supported
    }

    fn discard(&mut self, _cx: &mut Context<Self>) {
        self.current_completion = None;
        // TODO: Could send discard telemetry to Ollama if supported
    }

    fn suggest(
        &mut self,
        buffer: &Entity<Buffer>,
        cursor_position: Anchor,
        cx: &mut Context<Self>,
    ) -> Option<InlineCompletion> {
        let buffer_id = buffer.entity_id();
        if Some(buffer_id) != self.buffer_id {
            return None;
        }

        let completion_text = self.current_completion.as_ref()?.clone();

        if completion_text.trim().is_empty() {
            return None;
        }

        let buffer_snapshot = buffer.read(cx);
        let cursor_offset = cursor_position.to_offset(buffer_snapshot);

        // Get text before cursor to check what's already been typed
        let text_before_cursor = buffer_snapshot
            .text_for_range(0..cursor_offset)
            .collect::<String>();

        // Find how much of the completion has already been typed by checking
        // if the text before the cursor ends with a prefix of our completion
        let mut prefix_len = 0;
        for i in 1..=completion_text.len().min(text_before_cursor.len()) {
            if text_before_cursor.ends_with(&completion_text[..i]) {
                prefix_len = i;
            }
        }

        // Only suggest the remaining part of the completion
        let remaining_completion = &completion_text[prefix_len..];

        if remaining_completion.trim().is_empty() {
            return None;
        }

        let position = cursor_position.bias_right(buffer_snapshot);

        Some(InlineCompletion {
            id: None,
            edits: vec![(position..position, remaining_completion.to_string())],
            edit_preview: None,
        })
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    use gpui::{AppContext, TestAppContext};

    use client;
    use language::Buffer;
    use project::Project;
    use settings::SettingsStore;

    fn init_test(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            theme::init(theme::LoadThemes::JustBase, cx);
            client::init_settings(cx);
            language::init(cx);
            editor::init_settings(cx);
            Project::init_settings(cx);
            workspace::init_settings(cx);
        });
    }

    /// Test that stop tokens are only added for CodeLlama code models
    #[gpui::test]
    fn test_get_stop_tokens(cx: &mut TestAppContext) {
        init_test(cx);

        // Test CodeLlama code model gets stop tokens
        let codellama_provider = cx.update(|cx| {
            cx.new(|cx| OllamaCompletionProvider::new("codellama:7b-code".to_string(), None, cx))
        });

        codellama_provider.read_with(cx, |provider, _| {
            assert_eq!(provider.get_stop_tokens(), Some(vec!["<EOT>".to_string()]));
        });

        // Test non-CodeLlama model doesn't get stop tokens
        let qwen_provider = cx.update(|cx| {
            cx.new(|cx| OllamaCompletionProvider::new("qwen2.5-coder:3b".to_string(), None, cx))
        });

        qwen_provider.read_with(cx, |provider, _| {
            assert_eq!(provider.get_stop_tokens(), None);
        });
    }

    #[gpui::test]
    async fn test_model_discovery(cx: &mut TestAppContext) {
        init_test(cx);

        // Create fake HTTP client
        let fake_http_client = Arc::new(crate::fake::FakeHttpClient::new());

        // Mock /api/tags response (list models)
        let models_response = serde_json::json!({
            "models": [
                {
                    "name": "qwen2.5-coder:3b",
                    "modified_at": "2024-01-01T00:00:00Z",
                    "size": 1000000,
                    "digest": "abc123",
                    "details": {
                        "format": "gguf",
                        "family": "qwen2",
                        "families": ["qwen2"],
                        "parameter_size": "3B",
                        "quantization_level": "Q4_0"
                    }
                },
                {
                    "name": "codellama:7b-code",
                    "modified_at": "2024-01-01T00:00:00Z",
                    "size": 2000000,
                    "digest": "def456",
                    "details": {
                        "format": "gguf",
                        "family": "codellama",
                        "families": ["codellama"],
                        "parameter_size": "7B",
                        "quantization_level": "Q4_0"
                    }
                },
                {
                    "name": "nomic-embed-text",
                    "modified_at": "2024-01-01T00:00:00Z",
                    "size": 500000,
                    "digest": "ghi789",
                    "details": {
                        "format": "gguf",
                        "family": "nomic-embed",
                        "families": ["nomic-embed"],
                        "parameter_size": "137M",
                        "quantization_level": "Q4_0"
                    }
                }
            ]
        });

        fake_http_client.set_response("/api/tags", models_response.to_string());

        // Mock /api/show responses for model capabilities
        let qwen_capabilities = serde_json::json!({
            "capabilities": ["tools", "thinking"]
        });

        let _codellama_capabilities = serde_json::json!({
            "capabilities": []
        });

        fake_http_client.set_response("/api/show", qwen_capabilities.to_string());

        // Create global Ollama service for testing
        let service = cx.update(|cx| {
            OllamaService::new(
                fake_http_client.clone(),
                "http://localhost:11434".to_string(),
                cx,
            )
        });

        // Set it as global
        cx.update(|cx| {
            OllamaService::set_global(service.clone(), cx);
        });

        // Create completion provider
        let provider = cx.update(|cx| {
            cx.new(|cx| OllamaCompletionProvider::new("qwen2.5-coder:3b".to_string(), None, cx))
        });

        // Wait for model discovery to complete
        cx.background_executor.run_until_parked();

        // Verify models were discovered through the global provider
        provider.read_with(cx, |provider, cx| {
            let models = provider.available_models(cx);
            assert_eq!(models.len(), 2); // Should exclude nomic-embed-text

            let model_names: Vec<&str> = models.iter().map(|m| m.name.as_str()).collect();
            assert!(model_names.contains(&"codellama:7b-code"));
            assert!(model_names.contains(&"qwen2.5-coder:3b"));
            assert!(!model_names.contains(&"nomic-embed-text"));
        });
    }

    #[gpui::test]
    async fn test_model_discovery_api_failure(cx: &mut TestAppContext) {
        init_test(cx);

        // Create fake HTTP client that returns errors
        let fake_http_client = Arc::new(crate::fake::FakeHttpClient::new());
        fake_http_client.set_error("Connection refused");

        // Create global Ollama service that will fail
        let service = cx.update(|cx| {
            OllamaService::new(
                fake_http_client.clone(),
                "http://localhost:11434".to_string(),
                cx,
            )
        });

        cx.update(|cx| {
            OllamaService::set_global(service.clone(), cx);
        });

        // Create completion provider
        let provider = cx.update(|cx| {
            cx.new(|cx| OllamaCompletionProvider::new("qwen2.5-coder:3b".to_string(), None, cx))
        });

        // Wait for model discovery to complete (with failure)
        cx.background_executor.run_until_parked();

        // Verify graceful handling - should have empty model list
        provider.read_with(cx, |provider, cx| {
            let models = provider.available_models(cx);
            assert_eq!(models.len(), 0);
        });
    }

    #[gpui::test]
    async fn test_refresh_models(cx: &mut TestAppContext) {
        init_test(cx);

        let fake_http_client = Arc::new(crate::fake::FakeHttpClient::new());

        // Initially return empty model list
        let empty_response = serde_json::json!({"models": []});
        fake_http_client.set_response("/api/tags", empty_response.to_string());

        // Create global Ollama service
        let service = cx.update(|cx| {
            OllamaService::new(
                fake_http_client.clone(),
                "http://localhost:11434".to_string(),
                cx,
            )
        });

        cx.update(|cx| {
            OllamaService::set_global(service.clone(), cx);
        });

        let provider = cx.update(|cx| {
            cx.new(|cx| OllamaCompletionProvider::new("qwen2.5-coder:7b".to_string(), None, cx))
        });

        cx.background_executor.run_until_parked();

        // Verify initially empty
        provider.read_with(cx, |provider, cx| {
            assert_eq!(provider.available_models(cx).len(), 0);
        });

        // Update mock to return models
        let models_response = serde_json::json!({
            "models": [
                {
                    "name": "qwen2.5-coder:7b",
                    "modified_at": "2024-01-01T00:00:00Z",
                    "size": 1000000,
                    "digest": "abc123",
                    "details": {
                        "format": "gguf",
                        "family": "qwen2",
                        "families": ["qwen2"],
                        "parameter_size": "7B",
                        "quantization_level": "Q4_0"
                    }
                }
            ]
        });

        fake_http_client.set_response("/api/tags", models_response.to_string());

        let capabilities = serde_json::json!({
            "capabilities": ["tools", "thinking"]
        });

        fake_http_client.set_response("/api/show", capabilities.to_string());

        // Trigger refresh
        provider.update(cx, |provider, cx| {
            provider.refresh_models(cx);
        });

        cx.background_executor.run_until_parked();

        // Verify models were refreshed
        provider.read_with(cx, |provider, cx| {
            let models = provider.available_models(cx);
            assert_eq!(models.len(), 1);
            assert_eq!(models[0].name, "qwen2.5-coder:7b");
        });
    }

    /// Test the complete Ollama completion flow from refresh to suggestion
    #[gpui::test]
    async fn test_full_completion_flow(cx: &mut TestAppContext) {
        init_test(cx);

        // Create a buffer with realistic code content
        let buffer = cx.update(|cx| cx.new(|cx| Buffer::local("fn test() {\n  \n}", cx)));
        let cursor_position = buffer.read_with(cx, |buffer, _| {
            buffer.anchor_before(11) // Position in the middle of the function
        });

        // Create fake HTTP client and set up global service
        let fake_http_client = Arc::new(crate::fake::FakeHttpClient::new());
        fake_http_client.set_generate_response("println!(\"Hello\");");

        // Create global Ollama service
        let service = cx.update(|cx| {
            OllamaService::new(
                fake_http_client.clone(),
                "http://localhost:11434".to_string(),
                cx,
            )
        });

        cx.update(|cx| {
            OllamaService::set_global(service.clone(), cx);
        });

        // Create provider
        let provider = cx.update(|cx| {
            cx.new(|cx| OllamaCompletionProvider::new("qwen2.5-coder:3b".to_string(), None, cx))
        });

        // Trigger completion refresh (no debounce for test speed)
        provider.update(cx, |provider, cx| {
            provider.refresh(None, buffer.clone(), cursor_position, false, cx);
        });

        // Wait for completion task to complete
        cx.background_executor.run_until_parked();

        // Verify completion was processed and stored
        provider.read_with(cx, |provider, _cx| {
            assert!(provider.current_completion.is_some());
            assert_eq!(
                provider.current_completion.as_ref().unwrap(),
                "println!(\"Hello\");"
            );
            assert!(!provider.is_refreshing());
        });

        // Test suggestion logic returns the completion
        let suggestion = cx.update(|cx| {
            provider.update(cx, |provider, cx| {
                provider.suggest(&buffer, cursor_position, cx)
            })
        });

        assert!(suggestion.is_some());
        let suggestion = suggestion.unwrap();
        assert_eq!(suggestion.edits.len(), 1);
        assert_eq!(suggestion.edits[0].1, "println!(\"Hello\");");

        // Verify acceptance clears the completion
        provider.update(cx, |provider, cx| {
            provider.accept(cx);
        });

        provider.read_with(cx, |provider, _cx| {
            assert!(provider.current_completion.is_none());
        });
    }

    /// Test that partial typing is handled correctly - only suggests untyped portion
    #[gpui::test]
    async fn test_partial_typing_handling(cx: &mut TestAppContext) {
        init_test(cx);

        // Create buffer where user has partially typed "vec"
        let buffer = cx.update(|cx| cx.new(|cx| Buffer::local("let result = vec", cx)));
        let cursor_position = buffer.read_with(cx, |buffer, _| {
            buffer.anchor_after(16) // After "vec"
        });

        // Create fake HTTP client and set up global service
        let fake_http_client = Arc::new(crate::fake::FakeHttpClient::new());

        // Create global Ollama service
        let service = cx.update(|cx| {
            OllamaService::new(
                fake_http_client.clone(),
                "http://localhost:11434".to_string(),
                cx,
            )
        });

        cx.update(|cx| {
            OllamaService::set_global(service.clone(), cx);
        });

        // Create provider
        let provider = cx.update(|cx| {
            cx.new(|cx| OllamaCompletionProvider::new("qwen2.5-coder:3b".to_string(), None, cx))
        });

        // Configure response that starts with what user already typed
        fake_http_client.set_generate_response("vec![1, 2, 3]");

        provider.update(cx, |provider, cx| {
            provider.refresh(None, buffer.clone(), cursor_position, false, cx);
        });

        cx.background_executor.run_until_parked();

        // Should suggest only the remaining part after "vec"
        let suggestion = cx.update(|cx| {
            provider.update(cx, |provider, cx| {
                provider.suggest(&buffer, cursor_position, cx)
            })
        });

        // Verify we get a reasonable suggestion
        if let Some(suggestion) = suggestion {
            assert_eq!(suggestion.edits.len(), 1);
            assert!(suggestion.edits[0].1.contains("1, 2, 3"));
        }
    }

    #[gpui::test]
    async fn test_accept_partial_ollama_suggestion(cx: &mut TestAppContext) {
        init_test(cx);

        let mut editor_cx = editor::test::editor_test_context::EditorTestContext::new(cx).await;

        // Create fake HTTP client and set up global service
        let fake_http_client = Arc::new(crate::fake::FakeHttpClient::new());
        fake_http_client.set_generate_response("vec![hello, world]");

        // Create global Ollama service
        let service = cx.update(|cx| {
            OllamaService::new(
                fake_http_client.clone(),
                "http://localhost:11434".to_string(),
                cx,
            )
        });

        cx.update(|cx| {
            OllamaService::set_global(service.clone(), cx);
        });

        // Create provider
        let provider = cx.update(|cx| {
            cx.new(|cx| OllamaCompletionProvider::new("qwen2.5-coder:3b".to_string(), None, cx))
        });

        // Set up the editor with the Ollama provider
        editor_cx.update_editor(|editor, window, cx| {
            editor.set_edit_prediction_provider(Some(provider.clone()), window, cx);
        });

        // Set initial state
        editor_cx.set_state("let items = ˇ");

        // Trigger the completion through the provider
        let buffer =
            editor_cx.multibuffer(|multibuffer, _| multibuffer.as_singleton().unwrap().clone());
        let cursor_position = editor_cx.buffer_snapshot().anchor_after(12);

        provider.update(cx, |provider, cx| {
            provider.refresh(None, buffer, cursor_position, false, cx);
        });

        cx.background_executor.run_until_parked();

        editor_cx.update_editor(|editor, window, cx| {
            editor.refresh_inline_completion(false, true, window, cx);
        });

        cx.background_executor.run_until_parked();

        editor_cx.update_editor(|editor, window, cx| {
            // Verify we have an active completion
            assert!(editor.has_active_inline_completion());

            // The display text should show the full completion
            assert_eq!(editor.display_text(cx), "let items = vec![hello, world]");
            // But the actual text should only show what's been typed
            assert_eq!(editor.text(cx), "let items = ");

            // Accept first partial - should accept "vec" (alphabetic characters)
            editor.accept_partial_inline_completion(&Default::default(), window, cx);

            // Assert the buffer now contains the first partially accepted text
            assert_eq!(editor.text(cx), "let items = vec");
            // Completion should still be active for remaining text
            assert!(editor.has_active_inline_completion());

            // Accept second partial - should accept "![" (non-alphabetic characters)
            editor.accept_partial_inline_completion(&Default::default(), window, cx);

            // Assert the buffer now contains both partial acceptances
            assert_eq!(editor.text(cx), "let items = vec![");
            // Completion should still be active for remaining text
            assert!(editor.has_active_inline_completion());
        });
    }

    #[gpui::test]
    async fn test_completion_invalidation(cx: &mut TestAppContext) {
        init_test(cx);

        let mut editor_cx = editor::test::editor_test_context::EditorTestContext::new(cx).await;

        // Create fake HTTP client and set up global service
        let fake_http_client = Arc::new(crate::fake::FakeHttpClient::new());
        fake_http_client.set_generate_response("bar");

        // Create global Ollama service
        let service = cx.update(|cx| {
            OllamaService::new(
                fake_http_client.clone(),
                "http://localhost:11434".to_string(),
                cx,
            )
        });

        cx.update(|cx| {
            OllamaService::set_global(service.clone(), cx);
        });

        // Create provider
        let provider = cx.update(|cx| {
            cx.new(|cx| OllamaCompletionProvider::new("qwen2.5-coder:3b".to_string(), None, cx))
        });

        // Set up the editor with the Ollama provider
        editor_cx.update_editor(|editor, window, cx| {
            editor.set_edit_prediction_provider(Some(provider.clone()), window, cx);
        });

        editor_cx.set_state("fooˇ");

        // Trigger the completion through the provider
        let buffer =
            editor_cx.multibuffer(|multibuffer, _| multibuffer.as_singleton().unwrap().clone());
        let cursor_position = editor_cx.buffer_snapshot().anchor_after(3); // After "foo"

        provider.update(cx, |provider, cx| {
            provider.refresh(None, buffer, cursor_position, false, cx);
        });

        cx.background_executor.run_until_parked();

        editor_cx.update_editor(|editor, window, cx| {
            editor.refresh_inline_completion(false, true, window, cx);
        });

        cx.background_executor.run_until_parked();

        editor_cx.update_editor(|editor, window, cx| {
            assert!(editor.has_active_inline_completion());
            assert_eq!(editor.display_text(cx), "foobar");
            assert_eq!(editor.text(cx), "foo");

            // Backspace within the original text - completion should remain
            editor.backspace(&Default::default(), window, cx);
            assert!(editor.has_active_inline_completion());
            assert_eq!(editor.display_text(cx), "fobar");
            assert_eq!(editor.text(cx), "fo");

            editor.backspace(&Default::default(), window, cx);
            assert!(editor.has_active_inline_completion());
            assert_eq!(editor.display_text(cx), "fbar");
            assert_eq!(editor.text(cx), "f");

            // This backspace removes all original text - should invalidate completion
            editor.backspace(&Default::default(), window, cx);
            assert!(!editor.has_active_inline_completion());
            assert_eq!(editor.display_text(cx), "");
            assert_eq!(editor.text(cx), "");
        });
    }

    #[gpui::test]
    async fn test_settings_model_merging(cx: &mut TestAppContext) {
        init_test(cx);

        // Create fake HTTP client that returns some API models
        let fake_http_client = Arc::new(crate::fake::FakeHttpClient::new());

        // Mock /api/tags response (list models)
        let models_response = serde_json::json!({
            "models": [
                {
                    "name": "api-model-1",
                    "modified_at": "2024-01-01T00:00:00Z",
                    "size": 1000000,
                    "digest": "abc123",
                    "details": {
                        "format": "gguf",
                        "family": "llama",
                        "families": ["llama"],
                        "parameter_size": "7B",
                        "quantization_level": "Q4_0"
                    }
                },
                {
                    "name": "shared-model",
                    "modified_at": "2024-01-01T00:00:00Z",
                    "size": 2000000,
                    "digest": "def456",
                    "details": {
                        "format": "gguf",
                        "family": "llama",
                        "families": ["llama"],
                        "parameter_size": "13B",
                        "quantization_level": "Q4_0"
                    }
                }
            ]
        });

        fake_http_client.set_response("/api/tags", models_response.to_string());

        // Mock /api/show responses for each model
        let show_response = serde_json::json!({
            "capabilities": ["tools", "vision"]
        });
        fake_http_client.set_response("/api/show", show_response.to_string());

        // Create service
        let service = cx.update(|cx| {
            OllamaService::new(
                fake_http_client.clone(),
                "http://localhost:11434".to_string(),
                cx,
            )
        });

        // Add settings models (including one that overlaps with API)
        let settings_models = vec![
            SettingsModel {
                name: "custom-model-1".to_string(),
                display_name: Some("Custom Model 1".to_string()),
                max_tokens: 4096,
                supports_tools: Some(true),
                supports_images: Some(false),
                supports_thinking: Some(false),
            },
            SettingsModel {
                name: "shared-model".to_string(), // This should take precedence over API
                display_name: Some("Custom Shared Model".to_string()),
                max_tokens: 8192,
                supports_tools: Some(true),
                supports_images: Some(true),
                supports_thinking: Some(true),
            },
        ];

        cx.update(|cx| {
            service.update(cx, |service, cx| {
                service.set_settings_models(settings_models, cx);
            });
        });

        // Wait for models to be fetched and merged
        cx.run_until_parked();

        // Verify merged models
        let models = cx.update(|cx| service.read(cx).available_models().to_vec());

        assert_eq!(models.len(), 3); // 2 settings models + 1 unique API model

        // Models should be sorted alphabetically, so check by name
        let model_names: Vec<&str> = models.iter().map(|m| m.name.as_str()).collect();
        assert_eq!(
            model_names,
            vec!["api-model-1", "custom-model-1", "shared-model"]
        );

        // Check custom model from settings
        let custom_model = models.iter().find(|m| m.name == "custom-model-1").unwrap();
        assert_eq!(
            custom_model.display_name,
            Some("Custom Model 1".to_string())
        );
        assert_eq!(custom_model.max_tokens, 4096);

        // Settings model should override API model for shared-model
        let shared_model = models.iter().find(|m| m.name == "shared-model").unwrap();
        assert_eq!(
            shared_model.display_name,
            Some("Custom Shared Model".to_string())
        );
        assert_eq!(shared_model.max_tokens, 8192);
        assert_eq!(shared_model.supports_tools, Some(true));
        assert_eq!(shared_model.supports_vision, Some(true));
        assert_eq!(shared_model.supports_thinking, Some(true));

        // API-only model should be included
        let api_model = models.iter().find(|m| m.name == "api-model-1").unwrap();
        assert!(api_model.display_name.is_none()); // API models don't have custom display names
    }
}