Support using ollama as an inline_completion_provider

Oliver Azevedo Barnes 2025-06-29 13:29:45 -03:00
parent 047d515abf
commit 72d0b2402a
9 changed files with 535 additions and 4 deletions

Cargo.lock (generated)

@@ -10807,10 +10807,17 @@ version = "0.1.0"
dependencies = [
"anyhow",
"futures 0.3.31",
"gpui",
"http_client",
"indoc",
"inline_completion",
"language",
"multi_buffer",
"project",
"schemars",
"serde",
"serde_json",
"text",
"workspace-hack",
]
@@ -20005,6 +20012,7 @@ dependencies = [
"nix 0.29.0",
"node_runtime",
"notifications",
"ollama",
"outline",
"outline_panel",
"parking_lot",


@@ -358,6 +358,41 @@ impl Render for InlineCompletionButton {
div().child(popover_menu.into_any_element())
}
EditPredictionProvider::Ollama => {
let enabled = self.editor_enabled.unwrap_or(false);
let icon = if enabled {
IconName::AiOllama
} else {
IconName::AiOllama // Could add disabled variant
};
let this = cx.entity().clone();
div().child(
PopoverMenu::new("ollama")
.menu(move |window, cx| {
Some(
this.update(cx, |this, cx| {
this.build_ollama_context_menu(window, cx)
}),
)
})
.trigger(
IconButton::new("ollama-completion", icon)
.icon_size(IconSize::Small)
.tooltip(|window, cx| {
Tooltip::for_action(
"Ollama Completion",
&ToggleMenu,
window,
cx,
)
}),
)
.with_handle(self.popover_menu_handle.clone()),
)
}
}
}
}
@@ -805,6 +840,26 @@ impl InlineCompletionButton {
})
}
fn build_ollama_context_menu(
&self,
window: &mut Window,
cx: &mut Context<Self>,
) -> Entity<ContextMenu> {
let fs = self.fs.clone();
ContextMenu::build(window, cx, |menu, _window, _cx| {
menu.entry("Toggle Ollama Completions", None, {
let fs = fs.clone();
move |_window, cx| {
toggle_inline_completions_globally(fs.clone(), cx);
}
})
.entry("Ollama Settings...", None, |_window, cx| {
// TODO: Open Ollama-specific settings
cx.open_url("http://localhost:11434");
})
})
}
pub fn update_enabled(&mut self, editor: Entity<Editor>, cx: &mut Context<Self>) {
let editor = editor.read(cx);
let snapshot = editor.buffer().read(cx).snapshot(cx);


@@ -216,6 +216,7 @@ pub enum EditPredictionProvider {
Copilot,
Supermaven,
Zed,
Ollama,
}
impl EditPredictionProvider {
@@ -224,7 +225,8 @@ impl EditPredictionProvider {
EditPredictionProvider::Zed => true,
EditPredictionProvider::None
| EditPredictionProvider::Copilot
| EditPredictionProvider::Supermaven => false,
| EditPredictionProvider::Supermaven
| EditPredictionProvider::Ollama => false,
}
}
}


@@ -9,17 +9,34 @@ license = "GPL-3.0-or-later"
workspace = true
[lib]
path = "src/ollama.rs"
path = "src/lib.rs"
[features]
default = []
schemars = ["dep:schemars"]
test-support = [
"gpui/test-support",
"http_client/test-support",
"language/test-support",
]
[dependencies]
anyhow.workspace = true
futures.workspace = true
gpui.workspace = true
http_client.workspace = true
inline_completion.workspace = true
language.workspace = true
multi_buffer.workspace = true
project.workspace = true
schemars = { workspace = true, optional = true }
serde.workspace = true
serde_json.workspace = true
text.workspace = true
workspace-hack.workspace = true
[dev-dependencies]
gpui = { workspace = true, features = ["test-support"] }
http_client = { workspace = true, features = ["test-support"] }
indoc.workspace = true
language = { workspace = true, features = ["test-support"] }

crates/ollama/src/lib.rs (new file)

@@ -0,0 +1,5 @@
mod ollama;
mod ollama_completion_provider;
pub use ollama::*;
pub use ollama_completion_provider::*;


@@ -98,6 +98,38 @@ impl Model {
}
}
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Debug, Serialize, Deserialize)]
pub struct GenerateRequest {
pub model: String,
pub prompt: String,
pub stream: bool,
pub options: Option<GenerateOptions>,
pub keep_alive: Option<KeepAlive>,
pub context: Option<Vec<i64>>,
}
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Debug, Serialize, Deserialize)]
pub struct GenerateOptions {
pub num_predict: Option<i32>,
pub temperature: Option<f32>,
pub top_p: Option<f32>,
pub stop: Option<Vec<String>>,
}
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Debug, Serialize, Deserialize)]
pub struct GenerateResponse {
pub response: String,
pub done: bool,
pub context: Option<Vec<i64>>,
pub total_duration: Option<u64>,
pub load_duration: Option<u64>,
pub prompt_eval_count: Option<i32>,
pub eval_count: Option<i32>,
}
#[derive(Serialize, Deserialize, Debug)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum ChatMessage {
@@ -359,6 +391,36 @@ pub async fn show_model(client: &dyn HttpClient, api_url: &str, model: &str) ->
Ok(details)
}
pub async fn generate(
client: &dyn HttpClient,
api_url: &str,
request: GenerateRequest,
) -> Result<GenerateResponse> {
let uri = format!("{api_url}/api/generate");
let request_builder = HttpRequest::builder()
.method(Method::POST)
.uri(uri)
.header("Content-Type", "application/json");
let serialized_request = serde_json::to_string(&request)?;
let request = request_builder.body(AsyncBody::from(serialized_request))?;
let mut response = client.send(request).await?;
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;
anyhow::ensure!(
response.status().is_success(),
"Failed to connect to Ollama API: {} {}",
response.status(),
body,
);
let response: GenerateResponse =
serde_json::from_str(&body).context("Unable to parse Ollama generate response")?;
Ok(response)
}
#[cfg(test)]
mod tests {
use super::*;

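Not part of the diff: a minimal usage sketch of the non-streaming generate API added above. The GenerateRequest and GenerateOptions fields and the generate signature come from the hunk; the surrounding function, model name, and prompt are assumptions for illustration.

use std::sync::Arc;

use anyhow::Result;
use http_client::HttpClient;
use ollama::{GenerateOptions, GenerateRequest, generate};

// Sketch only: assumes an HttpClient is already available (Zed passes its own
// client into the provider) and that a local Ollama instance serves the model.
async fn example_generate(client: Arc<dyn HttpClient>) -> Result<String> {
    let request = GenerateRequest {
        model: "codellama:7b".to_string(), // assumed model name
        prompt: "<PRE> fn add(a: i32, b: i32) -> i32 { <SUF>} <MID>".to_string(),
        stream: false,
        options: Some(GenerateOptions {
            num_predict: Some(64),
            temperature: Some(0.1),
            top_p: Some(0.95),
            stop: None,
        }),
        keep_alive: None,
        context: None,
    };
    let response = generate(client.as_ref(), "http://localhost:11434", request).await?;
    Ok(response.response)
}

The completion provider in the next file follows the same call path, with the prompt assembled from the buffer text around the cursor.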

@@ -0,0 +1,363 @@
use crate::{GenerateOptions, GenerateRequest, generate};
use anyhow::{Context as AnyhowContext, Result};
use gpui::{App, Context, Entity, EntityId, Task};
use http_client::HttpClient;
use inline_completion::{Direction, EditPredictionProvider, InlineCompletion};
use language::{Anchor, Buffer, ToOffset};
use project::Project;
use std::{path::Path, sync::Arc, time::Duration};
pub const OLLAMA_DEBOUNCE_TIMEOUT: Duration = Duration::from_millis(75);
pub struct OllamaCompletionProvider {
http_client: Arc<dyn HttpClient>,
api_url: String,
model: String,
buffer_id: Option<EntityId>,
file_extension: Option<String>,
current_completion: Option<String>,
pending_refresh: Option<Task<Result<()>>>,
}
impl OllamaCompletionProvider {
pub fn new(http_client: Arc<dyn HttpClient>, api_url: String, model: String) -> Self {
Self {
http_client,
api_url,
model,
buffer_id: None,
file_extension: None,
current_completion: None,
pending_refresh: None,
}
}
fn build_fim_prompt(&self, prefix: &str, suffix: &str) -> String {
// Use model-specific FIM patterns
match self.model.as_str() {
m if m.contains("codellama") => {
format!("<PRE> {prefix} <SUF>{suffix} <MID>")
}
m if m.contains("deepseek") => {
format!("<fim▁begin>{prefix}<fim▁hole>{suffix}<fim▁end>")
}
m if m.contains("starcoder") => {
format!("<fim_prefix>{prefix}<fim_suffix>{suffix}<fim_middle>")
}
_ => {
// Generic FIM pattern - fallback for models without specific support
format!("// Complete the following code:\n{prefix}\n// COMPLETION HERE\n{suffix}")
}
}
}
fn extract_context(&self, buffer: &Buffer, cursor_position: Anchor) -> (String, String) {
let cursor_offset = cursor_position.to_offset(buffer);
let text = buffer.text();
// Get reasonable context around cursor
let context_size = 2000; // 2KB before and after cursor
let start = cursor_offset.saturating_sub(context_size);
let end = (cursor_offset + context_size).min(text.len());
let prefix = text[start..cursor_offset].to_string();
let suffix = text[cursor_offset..end].to_string();
(prefix, suffix)
}
}
impl EditPredictionProvider for OllamaCompletionProvider {
fn name() -> &'static str {
"ollama"
}
fn display_name() -> &'static str {
"Ollama"
}
fn show_completions_in_menu() -> bool {
false
}
fn is_enabled(&self, _buffer: &Entity<Buffer>, _cursor_position: Anchor, _cx: &App) -> bool {
// TODO: Could ping Ollama API to check if it's running
true
}
fn is_refreshing(&self) -> bool {
self.pending_refresh.is_some()
}
fn refresh(
&mut self,
_project: Option<Entity<Project>>,
buffer: Entity<Buffer>,
cursor_position: Anchor,
debounce: bool,
cx: &mut Context<Self>,
) {
let http_client = self.http_client.clone();
let api_url = self.api_url.clone();
let model = self.model.clone();
self.pending_refresh = Some(cx.spawn(async move |this, cx| {
if debounce {
cx.background_executor()
.timer(OLLAMA_DEBOUNCE_TIMEOUT)
.await;
}
let (prefix, suffix) = this.update(cx, |this, cx| {
let buffer_snapshot = buffer.read(cx);
this.buffer_id = Some(buffer.entity_id());
this.file_extension = buffer_snapshot.file().and_then(|file| {
Some(
Path::new(file.file_name(cx))
.extension()?
.to_str()?
.to_string(),
)
});
this.extract_context(buffer_snapshot, cursor_position)
})?;
let prompt = this.update(cx, |this, _| this.build_fim_prompt(&prefix, &suffix))?;
let request = GenerateRequest {
model: model.clone(),
prompt,
stream: false,
options: Some(GenerateOptions {
num_predict: Some(150), // Reasonable completion length
temperature: Some(0.1), // Low temperature for more deterministic results
top_p: Some(0.95),
stop: Some(vec![
"\n\n".to_string(),
"```".to_string(),
"</PRE>".to_string(),
"<SUF>".to_string(),
]),
}),
keep_alive: None,
context: None,
};
let response = generate(http_client.as_ref(), &api_url, request)
.await
.context("Failed to get completion from Ollama")?;
this.update(cx, |this, cx| {
this.pending_refresh = None;
if !response.response.trim().is_empty() {
this.current_completion = Some(response.response);
} else {
this.current_completion = None;
}
cx.notify();
})?;
Ok(())
}));
}
fn cycle(
&mut self,
_buffer: Entity<Buffer>,
_cursor_position: Anchor,
_direction: Direction,
_cx: &mut Context<Self>,
) {
// Ollama doesn't provide multiple completions in a single request
// Could be implemented by making multiple requests with different temperatures
// or by using different models
}
fn accept(&mut self, _cx: &mut Context<Self>) {
self.current_completion = None;
// TODO: Could send accept telemetry to Ollama if supported
}
fn discard(&mut self, _cx: &mut Context<Self>) {
self.current_completion = None;
// TODO: Could send discard telemetry to Ollama if supported
}
fn suggest(
&mut self,
buffer: &Entity<Buffer>,
cursor_position: Anchor,
cx: &mut Context<Self>,
) -> Option<InlineCompletion> {
let buffer_id = buffer.entity_id();
if Some(buffer_id) != self.buffer_id {
return None;
}
let completion_text = self.current_completion.as_ref()?.clone();
if completion_text.trim().is_empty() {
return None;
}
let buffer_snapshot = buffer.read(cx);
let position = cursor_position.bias_right(buffer_snapshot);
// Clean up the completion text
let completion_text = completion_text.trim_start().trim_end();
Some(InlineCompletion {
id: None,
edits: vec![(position..position, completion_text.to_string())],
edit_preview: None,
})
}
}
#[cfg(test)]
mod tests {
use super::*;
use gpui::{AppContext, TestAppContext};
use http_client::FakeHttpClient;
use std::sync::Arc;
#[gpui::test]
async fn test_fim_prompt_patterns(_cx: &mut TestAppContext) {
let provider = OllamaCompletionProvider::new(
Arc::new(FakeHttpClient::with_404_response()),
"http://localhost:11434".to_string(),
"codellama:7b".to_string(),
);
let prefix = "function hello() {";
let suffix = "}";
let prompt = provider.build_fim_prompt(prefix, suffix);
assert!(prompt.contains("<PRE>"));
assert!(prompt.contains("<SUF>"));
assert!(prompt.contains("<MID>"));
assert!(prompt.contains(prefix));
assert!(prompt.contains(suffix));
}
#[gpui::test]
async fn test_fim_prompt_deepseek_pattern(_cx: &mut TestAppContext) {
let provider = OllamaCompletionProvider::new(
Arc::new(FakeHttpClient::with_404_response()),
"http://localhost:11434".to_string(),
"deepseek-coder:6.7b".to_string(),
);
let prefix = "def hello():";
let suffix = " pass";
let prompt = provider.build_fim_prompt(prefix, suffix);
assert!(prompt.contains("<fim▁begin>"));
assert!(prompt.contains("<fim▁hole>"));
assert!(prompt.contains("<fim▁end>"));
}
#[gpui::test]
async fn test_fim_prompt_starcoder_pattern(_cx: &mut TestAppContext) {
let provider = OllamaCompletionProvider::new(
Arc::new(FakeHttpClient::with_404_response()),
"http://localhost:11434".to_string(),
"starcoder:7b".to_string(),
);
let prefix = "def hello():";
let suffix = " pass";
let prompt = provider.build_fim_prompt(prefix, suffix);
assert!(prompt.contains("<fim_prefix>"));
assert!(prompt.contains("<fim_suffix>"));
assert!(prompt.contains("<fim_middle>"));
}
#[gpui::test]
async fn test_extract_context(cx: &mut TestAppContext) {
let provider = OllamaCompletionProvider::new(
Arc::new(FakeHttpClient::with_404_response()),
"http://localhost:11434".to_string(),
"codellama:7b".to_string(),
);
// Create a simple buffer using test context
let buffer_text = "function example() {\n    let x = 1;\n    let y = 2;\n    // cursor here\n    return x + y;\n}";
let buffer = cx.new(|cx| language::Buffer::local(buffer_text, cx));
// Position the cursor inside the "// cursor here" comment (just after "// cursor h")
let (prefix, suffix, _cursor_position) = cx.read(|cx| {
let buffer_snapshot = buffer.read(cx);
let cursor_position = buffer_snapshot.anchor_after(text::Point::new(3, 15)); // just after "// cursor h"
let (prefix, suffix) = provider.extract_context(&buffer_snapshot, cursor_position);
(prefix, suffix, cursor_position)
});
assert!(prefix.contains("function example()"));
assert!(prefix.contains("// cursor h"));
assert!(suffix.contains("ere"));
assert!(suffix.contains("return x + y"));
assert!(suffix.contains("}"));
}
#[gpui::test]
async fn test_suggest_with_completion(cx: &mut TestAppContext) {
let provider = cx.new(|_| {
OllamaCompletionProvider::new(
Arc::new(FakeHttpClient::with_404_response()),
"http://localhost:11434".to_string(),
"codellama:7b".to_string(),
)
});
let buffer_text = "// test";
let buffer = cx.new(|cx| language::Buffer::local(buffer_text, cx));
// Set up a mock completion
provider.update(cx, |provider, _| {
provider.current_completion = Some("console.log('hello');".to_string());
provider.buffer_id = Some(buffer.entity_id());
});
let cursor_position = cx.read(|cx| buffer.read(cx).anchor_after(text::Point::new(0, 7)));
let completion = provider.update(cx, |provider, cx| {
provider.suggest(&buffer, cursor_position, cx)
});
assert!(completion.is_some());
let completion = completion.unwrap();
assert_eq!(completion.edits.len(), 1);
assert_eq!(completion.edits[0].1, "console.log('hello');");
}
#[gpui::test]
async fn test_suggest_empty_completion(cx: &mut TestAppContext) {
let provider = cx.new(|_| {
OllamaCompletionProvider::new(
Arc::new(FakeHttpClient::with_404_response()),
"http://localhost:11434".to_string(),
"codellama:7b".to_string(),
)
});
let buffer_text = "// test";
let buffer = cx.new(|cx| language::Buffer::local(buffer_text, cx));
// Set up an empty completion
provider.update(cx, |provider, _| {
provider.current_completion = Some(" ".to_string()); // Only whitespace
provider.buffer_id = Some(buffer.entity_id());
});
let cursor_position = cx.read(|cx| buffer.read(cx).anchor_after(text::Point::new(0, 7)));
let completion = provider.update(cx, |provider, cx| {
provider.suggest(&buffer, cursor_position, cx)
});
assert!(completion.is_none());
}
}
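Also not part of the diff: the fill-in-the-middle path above (extract_context plus the codellama branch of build_fim_prompt) reduced to plain strings, as a sketch of what actually gets sent to the model. The sample text and cursor offset are assumptions.

// Sketch only: mirrors the fixed 2000-character window of extract_context and
// the codellama prompt pattern, with a plain &str standing in for the Buffer.
fn sketch_fim_prompt(text: &str, cursor_offset: usize) -> String {
    let context_size = 2000;
    let start = cursor_offset.saturating_sub(context_size);
    let end = (cursor_offset + context_size).min(text.len());
    let prefix = &text[start..cursor_offset];
    let suffix = &text[cursor_offset..end];
    format!("<PRE> {prefix} <SUF>{suffix} <MID>")
}

// e.g. sketch_fim_prompt("fn add(a: i32, b: i32) -> i32 {}", 31)
// returns "<PRE> fn add(a: i32, b: i32) -> i32 { <SUF>} <MID>"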


@@ -71,6 +71,7 @@ image_viewer.workspace = true
indoc.workspace = true
inline_completion_button.workspace = true
inspector_ui.workspace = true
ollama.workspace = true
install_cli.workspace = true
jj_ui.workspace = true
journal.workspace = true


@@ -4,7 +4,9 @@ use copilot::{Copilot, CopilotCompletionProvider};
use editor::Editor;
use gpui::{AnyWindowHandle, App, AppContext as _, Context, Entity, WeakEntity};
use language::language_settings::{EditPredictionProvider, all_language_settings};
use settings::SettingsStore;
use language_models::AllLanguageModelSettings;
use ollama::OllamaCompletionProvider;
use settings::{Settings, SettingsStore};
use smol::stream::StreamExt;
use std::{cell::RefCell, rc::Rc, sync::Arc};
use supermaven::{Supermaven, SupermavenCompletionProvider};
@@ -129,7 +131,8 @@ pub fn init(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) {
}
EditPredictionProvider::None
| EditPredictionProvider::Copilot
| EditPredictionProvider::Supermaven => {}
| EditPredictionProvider::Supermaven
| EditPredictionProvider::Ollama => {}
}
}
}
@@ -283,5 +286,20 @@ fn assign_edit_prediction_provider(
editor.set_edit_prediction_provider(Some(provider), window, cx);
}
}
EditPredictionProvider::Ollama => {
let settings = &AllLanguageModelSettings::get_global(cx).ollama;
let api_url = settings.api_url.clone();
// Use first available model or default
let model = settings
.available_models
.first()
.map(|m| m.name.clone())
.unwrap_or_else(|| "codellama:7b".to_string());
let provider =
cx.new(|_| OllamaCompletionProvider::new(client.http_client(), api_url, model));
editor.set_edit_prediction_provider(Some(provider), window, cx);
}
}
}