Support using ollama as an inline_completion_provider
This commit is contained in:
parent
047d515abf
commit
72d0b2402a
9 changed files with 535 additions and 4 deletions
8
Cargo.lock
generated
8
Cargo.lock
generated
|
@ -10807,10 +10807,17 @@ version = "0.1.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"anyhow",
|
"anyhow",
|
||||||
"futures 0.3.31",
|
"futures 0.3.31",
|
||||||
|
"gpui",
|
||||||
"http_client",
|
"http_client",
|
||||||
|
"indoc",
|
||||||
|
"inline_completion",
|
||||||
|
"language",
|
||||||
|
"multi_buffer",
|
||||||
|
"project",
|
||||||
"schemars",
|
"schemars",
|
||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
|
"text",
|
||||||
"workspace-hack",
|
"workspace-hack",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -20005,6 +20012,7 @@ dependencies = [
|
||||||
"nix 0.29.0",
|
"nix 0.29.0",
|
||||||
"node_runtime",
|
"node_runtime",
|
||||||
"notifications",
|
"notifications",
|
||||||
|
"ollama",
|
||||||
"outline",
|
"outline",
|
||||||
"outline_panel",
|
"outline_panel",
|
||||||
"parking_lot",
|
"parking_lot",
|
||||||
|
|
|
@ -358,6 +358,41 @@ impl Render for InlineCompletionButton {
|
||||||
|
|
||||||
div().child(popover_menu.into_any_element())
|
div().child(popover_menu.into_any_element())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
EditPredictionProvider::Ollama => {
|
||||||
|
let enabled = self.editor_enabled.unwrap_or(false);
|
||||||
|
let icon = if enabled {
|
||||||
|
IconName::AiOllama
|
||||||
|
} else {
|
||||||
|
IconName::AiOllama // Could add disabled variant
|
||||||
|
};
|
||||||
|
|
||||||
|
let this = cx.entity().clone();
|
||||||
|
|
||||||
|
div().child(
|
||||||
|
PopoverMenu::new("ollama")
|
||||||
|
.menu(move |window, cx| {
|
||||||
|
Some(
|
||||||
|
this.update(cx, |this, cx| {
|
||||||
|
this.build_ollama_context_menu(window, cx)
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
})
|
||||||
|
.trigger(
|
||||||
|
IconButton::new("ollama-completion", icon)
|
||||||
|
.icon_size(IconSize::Small)
|
||||||
|
.tooltip(|window, cx| {
|
||||||
|
Tooltip::for_action(
|
||||||
|
"Ollama Completion",
|
||||||
|
&ToggleMenu,
|
||||||
|
window,
|
||||||
|
cx,
|
||||||
|
)
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
.with_handle(self.popover_menu_handle.clone()),
|
||||||
|
)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -805,6 +840,26 @@ impl InlineCompletionButton {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn build_ollama_context_menu(
|
||||||
|
&self,
|
||||||
|
window: &mut Window,
|
||||||
|
cx: &mut Context<Self>,
|
||||||
|
) -> Entity<ContextMenu> {
|
||||||
|
let fs = self.fs.clone();
|
||||||
|
ContextMenu::build(window, cx, |menu, _window, _cx| {
|
||||||
|
menu.entry("Toggle Ollama Completions", None, {
|
||||||
|
let fs = fs.clone();
|
||||||
|
move |_window, cx| {
|
||||||
|
toggle_inline_completions_globally(fs.clone(), cx);
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.entry("Ollama Settings...", None, |_window, cx| {
|
||||||
|
// TODO: Open Ollama-specific settings
|
||||||
|
cx.open_url("http://localhost:11434");
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
pub fn update_enabled(&mut self, editor: Entity<Editor>, cx: &mut Context<Self>) {
|
pub fn update_enabled(&mut self, editor: Entity<Editor>, cx: &mut Context<Self>) {
|
||||||
let editor = editor.read(cx);
|
let editor = editor.read(cx);
|
||||||
let snapshot = editor.buffer().read(cx).snapshot(cx);
|
let snapshot = editor.buffer().read(cx).snapshot(cx);
|
||||||
|
|
|
@ -216,6 +216,7 @@ pub enum EditPredictionProvider {
|
||||||
Copilot,
|
Copilot,
|
||||||
Supermaven,
|
Supermaven,
|
||||||
Zed,
|
Zed,
|
||||||
|
Ollama,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl EditPredictionProvider {
|
impl EditPredictionProvider {
|
||||||
|
@ -224,7 +225,8 @@ impl EditPredictionProvider {
|
||||||
EditPredictionProvider::Zed => true,
|
EditPredictionProvider::Zed => true,
|
||||||
EditPredictionProvider::None
|
EditPredictionProvider::None
|
||||||
| EditPredictionProvider::Copilot
|
| EditPredictionProvider::Copilot
|
||||||
| EditPredictionProvider::Supermaven => false,
|
| EditPredictionProvider::Supermaven
|
||||||
|
| EditPredictionProvider::Ollama => false,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -9,17 +9,34 @@ license = "GPL-3.0-or-later"
|
||||||
workspace = true
|
workspace = true
|
||||||
|
|
||||||
[lib]
|
[lib]
|
||||||
path = "src/ollama.rs"
|
path = "src/lib.rs"
|
||||||
|
|
||||||
[features]
|
[features]
|
||||||
default = []
|
default = []
|
||||||
schemars = ["dep:schemars"]
|
schemars = ["dep:schemars"]
|
||||||
|
test-support = [
|
||||||
|
"gpui/test-support",
|
||||||
|
"http_client/test-support",
|
||||||
|
"language/test-support",
|
||||||
|
]
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
anyhow.workspace = true
|
anyhow.workspace = true
|
||||||
futures.workspace = true
|
futures.workspace = true
|
||||||
|
gpui.workspace = true
|
||||||
http_client.workspace = true
|
http_client.workspace = true
|
||||||
|
inline_completion.workspace = true
|
||||||
|
language.workspace = true
|
||||||
|
multi_buffer.workspace = true
|
||||||
|
project.workspace = true
|
||||||
schemars = { workspace = true, optional = true }
|
schemars = { workspace = true, optional = true }
|
||||||
serde.workspace = true
|
serde.workspace = true
|
||||||
serde_json.workspace = true
|
serde_json.workspace = true
|
||||||
|
text.workspace = true
|
||||||
workspace-hack.workspace = true
|
workspace-hack.workspace = true
|
||||||
|
|
||||||
|
[dev-dependencies]
|
||||||
|
gpui = { workspace = true, features = ["test-support"] }
|
||||||
|
http_client = { workspace = true, features = ["test-support"] }
|
||||||
|
indoc.workspace = true
|
||||||
|
language = { workspace = true, features = ["test-support"] }
|
||||||
|
|
5
crates/ollama/src/lib.rs
Normal file
5
crates/ollama/src/lib.rs
Normal file
|
@ -0,0 +1,5 @@
|
||||||
|
mod ollama;
|
||||||
|
mod ollama_completion_provider;
|
||||||
|
|
||||||
|
pub use ollama::*;
|
||||||
|
pub use ollama_completion_provider::*;
|
|
@ -98,6 +98,38 @@ impl Model {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
pub struct GenerateRequest {
|
||||||
|
pub model: String,
|
||||||
|
pub prompt: String,
|
||||||
|
pub stream: bool,
|
||||||
|
pub options: Option<GenerateOptions>,
|
||||||
|
pub keep_alive: Option<KeepAlive>,
|
||||||
|
pub context: Option<Vec<i64>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
pub struct GenerateOptions {
|
||||||
|
pub num_predict: Option<i32>,
|
||||||
|
pub temperature: Option<f32>,
|
||||||
|
pub top_p: Option<f32>,
|
||||||
|
pub stop: Option<Vec<String>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
pub struct GenerateResponse {
|
||||||
|
pub response: String,
|
||||||
|
pub done: bool,
|
||||||
|
pub context: Option<Vec<i64>>,
|
||||||
|
pub total_duration: Option<u64>,
|
||||||
|
pub load_duration: Option<u64>,
|
||||||
|
pub prompt_eval_count: Option<i32>,
|
||||||
|
pub eval_count: Option<i32>,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize, Debug)]
|
#[derive(Serialize, Deserialize, Debug)]
|
||||||
#[serde(tag = "role", rename_all = "lowercase")]
|
#[serde(tag = "role", rename_all = "lowercase")]
|
||||||
pub enum ChatMessage {
|
pub enum ChatMessage {
|
||||||
|
@ -359,6 +391,36 @@ pub async fn show_model(client: &dyn HttpClient, api_url: &str, model: &str) ->
|
||||||
Ok(details)
|
Ok(details)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub async fn generate(
|
||||||
|
client: &dyn HttpClient,
|
||||||
|
api_url: &str,
|
||||||
|
request: GenerateRequest,
|
||||||
|
) -> Result<GenerateResponse> {
|
||||||
|
let uri = format!("{api_url}/api/generate");
|
||||||
|
let request_builder = HttpRequest::builder()
|
||||||
|
.method(Method::POST)
|
||||||
|
.uri(uri)
|
||||||
|
.header("Content-Type", "application/json");
|
||||||
|
|
||||||
|
let serialized_request = serde_json::to_string(&request)?;
|
||||||
|
let request = request_builder.body(AsyncBody::from(serialized_request))?;
|
||||||
|
|
||||||
|
let mut response = client.send(request).await?;
|
||||||
|
let mut body = String::new();
|
||||||
|
response.body_mut().read_to_string(&mut body).await?;
|
||||||
|
|
||||||
|
anyhow::ensure!(
|
||||||
|
response.status().is_success(),
|
||||||
|
"Failed to connect to Ollama API: {} {}",
|
||||||
|
response.status(),
|
||||||
|
body,
|
||||||
|
);
|
||||||
|
|
||||||
|
let response: GenerateResponse =
|
||||||
|
serde_json::from_str(&body).context("Unable to parse Ollama generate response")?;
|
||||||
|
Ok(response)
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
363
crates/ollama/src/ollama_completion_provider.rs
Normal file
363
crates/ollama/src/ollama_completion_provider.rs
Normal file
|
@ -0,0 +1,363 @@
|
||||||
|
use crate::{GenerateOptions, GenerateRequest, generate};
|
||||||
|
use anyhow::{Context as AnyhowContext, Result};
|
||||||
|
use gpui::{App, Context, Entity, EntityId, Task};
|
||||||
|
use http_client::HttpClient;
|
||||||
|
use inline_completion::{Direction, EditPredictionProvider, InlineCompletion};
|
||||||
|
use language::{Anchor, Buffer, ToOffset};
|
||||||
|
use project::Project;
|
||||||
|
use std::{path::Path, sync::Arc, time::Duration};
|
||||||
|
|
||||||
|
pub const OLLAMA_DEBOUNCE_TIMEOUT: Duration = Duration::from_millis(75);
|
||||||
|
|
||||||
|
pub struct OllamaCompletionProvider {
|
||||||
|
http_client: Arc<dyn HttpClient>,
|
||||||
|
api_url: String,
|
||||||
|
model: String,
|
||||||
|
buffer_id: Option<EntityId>,
|
||||||
|
file_extension: Option<String>,
|
||||||
|
current_completion: Option<String>,
|
||||||
|
pending_refresh: Option<Task<Result<()>>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl OllamaCompletionProvider {
|
||||||
|
pub fn new(http_client: Arc<dyn HttpClient>, api_url: String, model: String) -> Self {
|
||||||
|
Self {
|
||||||
|
http_client,
|
||||||
|
api_url,
|
||||||
|
model,
|
||||||
|
buffer_id: None,
|
||||||
|
file_extension: None,
|
||||||
|
current_completion: None,
|
||||||
|
pending_refresh: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn build_fim_prompt(&self, prefix: &str, suffix: &str) -> String {
|
||||||
|
// Use model-specific FIM patterns
|
||||||
|
match self.model.as_str() {
|
||||||
|
m if m.contains("codellama") => {
|
||||||
|
format!("<PRE> {prefix} <SUF>{suffix} <MID>")
|
||||||
|
}
|
||||||
|
m if m.contains("deepseek") => {
|
||||||
|
format!("<|fim▁begin|>{prefix}<|fim▁hole|>{suffix}<|fim▁end|>")
|
||||||
|
}
|
||||||
|
m if m.contains("starcoder") => {
|
||||||
|
format!("<fim_prefix>{prefix}<fim_suffix>{suffix}<fim_middle>")
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
// Generic FIM pattern - fallback for models without specific support
|
||||||
|
format!("// Complete the following code:\n{prefix}\n// COMPLETION HERE\n{suffix}")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn extract_context(&self, buffer: &Buffer, cursor_position: Anchor) -> (String, String) {
|
||||||
|
let cursor_offset = cursor_position.to_offset(buffer);
|
||||||
|
let text = buffer.text();
|
||||||
|
|
||||||
|
// Get reasonable context around cursor
|
||||||
|
let context_size = 2000; // 2KB before and after cursor
|
||||||
|
|
||||||
|
let start = cursor_offset.saturating_sub(context_size);
|
||||||
|
let end = (cursor_offset + context_size).min(text.len());
|
||||||
|
|
||||||
|
let prefix = text[start..cursor_offset].to_string();
|
||||||
|
let suffix = text[cursor_offset..end].to_string();
|
||||||
|
|
||||||
|
(prefix, suffix)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl EditPredictionProvider for OllamaCompletionProvider {
|
||||||
|
fn name() -> &'static str {
|
||||||
|
"ollama"
|
||||||
|
}
|
||||||
|
|
||||||
|
fn display_name() -> &'static str {
|
||||||
|
"Ollama"
|
||||||
|
}
|
||||||
|
|
||||||
|
fn show_completions_in_menu() -> bool {
|
||||||
|
false
|
||||||
|
}
|
||||||
|
|
||||||
|
fn is_enabled(&self, _buffer: &Entity<Buffer>, _cursor_position: Anchor, _cx: &App) -> bool {
|
||||||
|
// TODO: Could ping Ollama API to check if it's running
|
||||||
|
true
|
||||||
|
}
|
||||||
|
|
||||||
|
fn is_refreshing(&self) -> bool {
|
||||||
|
self.pending_refresh.is_some()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn refresh(
|
||||||
|
&mut self,
|
||||||
|
_project: Option<Entity<Project>>,
|
||||||
|
buffer: Entity<Buffer>,
|
||||||
|
cursor_position: Anchor,
|
||||||
|
debounce: bool,
|
||||||
|
cx: &mut Context<Self>,
|
||||||
|
) {
|
||||||
|
let http_client = self.http_client.clone();
|
||||||
|
let api_url = self.api_url.clone();
|
||||||
|
let model = self.model.clone();
|
||||||
|
|
||||||
|
self.pending_refresh = Some(cx.spawn(async move |this, cx| {
|
||||||
|
if debounce {
|
||||||
|
cx.background_executor()
|
||||||
|
.timer(OLLAMA_DEBOUNCE_TIMEOUT)
|
||||||
|
.await;
|
||||||
|
}
|
||||||
|
|
||||||
|
let (prefix, suffix) = this.update(cx, |this, cx| {
|
||||||
|
let buffer_snapshot = buffer.read(cx);
|
||||||
|
this.buffer_id = Some(buffer.entity_id());
|
||||||
|
this.file_extension = buffer_snapshot.file().and_then(|file| {
|
||||||
|
Some(
|
||||||
|
Path::new(file.file_name(cx))
|
||||||
|
.extension()?
|
||||||
|
.to_str()?
|
||||||
|
.to_string(),
|
||||||
|
)
|
||||||
|
});
|
||||||
|
this.extract_context(buffer_snapshot, cursor_position)
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let prompt = this.update(cx, |this, _| this.build_fim_prompt(&prefix, &suffix))?;
|
||||||
|
|
||||||
|
let request = GenerateRequest {
|
||||||
|
model: model.clone(),
|
||||||
|
prompt,
|
||||||
|
stream: false,
|
||||||
|
options: Some(GenerateOptions {
|
||||||
|
num_predict: Some(150), // Reasonable completion length
|
||||||
|
temperature: Some(0.1), // Low temperature for more deterministic results
|
||||||
|
top_p: Some(0.95),
|
||||||
|
stop: Some(vec![
|
||||||
|
"\n\n".to_string(),
|
||||||
|
"```".to_string(),
|
||||||
|
"</PRE>".to_string(),
|
||||||
|
"<SUF>".to_string(),
|
||||||
|
]),
|
||||||
|
}),
|
||||||
|
keep_alive: None,
|
||||||
|
context: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
let response = generate(http_client.as_ref(), &api_url, request)
|
||||||
|
.await
|
||||||
|
.context("Failed to get completion from Ollama")?;
|
||||||
|
|
||||||
|
this.update(cx, |this, cx| {
|
||||||
|
this.pending_refresh = None;
|
||||||
|
if !response.response.trim().is_empty() {
|
||||||
|
this.current_completion = Some(response.response);
|
||||||
|
} else {
|
||||||
|
this.current_completion = None;
|
||||||
|
}
|
||||||
|
cx.notify();
|
||||||
|
})?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
fn cycle(
|
||||||
|
&mut self,
|
||||||
|
_buffer: Entity<Buffer>,
|
||||||
|
_cursor_position: Anchor,
|
||||||
|
_direction: Direction,
|
||||||
|
_cx: &mut Context<Self>,
|
||||||
|
) {
|
||||||
|
// Ollama doesn't provide multiple completions in a single request
|
||||||
|
// Could be implemented by making multiple requests with different temperatures
|
||||||
|
// or by using different models
|
||||||
|
}
|
||||||
|
|
||||||
|
fn accept(&mut self, _cx: &mut Context<Self>) {
|
||||||
|
self.current_completion = None;
|
||||||
|
// TODO: Could send accept telemetry to Ollama if supported
|
||||||
|
}
|
||||||
|
|
||||||
|
fn discard(&mut self, _cx: &mut Context<Self>) {
|
||||||
|
self.current_completion = None;
|
||||||
|
// TODO: Could send discard telemetry to Ollama if supported
|
||||||
|
}
|
||||||
|
|
||||||
|
fn suggest(
|
||||||
|
&mut self,
|
||||||
|
buffer: &Entity<Buffer>,
|
||||||
|
cursor_position: Anchor,
|
||||||
|
cx: &mut Context<Self>,
|
||||||
|
) -> Option<InlineCompletion> {
|
||||||
|
let buffer_id = buffer.entity_id();
|
||||||
|
if Some(buffer_id) != self.buffer_id {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
let completion_text = self.current_completion.as_ref()?.clone();
|
||||||
|
|
||||||
|
if completion_text.trim().is_empty() {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
let buffer_snapshot = buffer.read(cx);
|
||||||
|
let position = cursor_position.bias_right(buffer_snapshot);
|
||||||
|
|
||||||
|
// Clean up the completion text
|
||||||
|
let completion_text = completion_text.trim_start().trim_end();
|
||||||
|
|
||||||
|
Some(InlineCompletion {
|
||||||
|
id: None,
|
||||||
|
edits: vec![(position..position, completion_text.to_string())],
|
||||||
|
edit_preview: None,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use gpui::{AppContext, TestAppContext};
|
||||||
|
use http_client::FakeHttpClient;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
#[gpui::test]
|
||||||
|
async fn test_fim_prompt_patterns(_cx: &mut TestAppContext) {
|
||||||
|
let provider = OllamaCompletionProvider::new(
|
||||||
|
Arc::new(FakeHttpClient::with_404_response()),
|
||||||
|
"http://localhost:11434".to_string(),
|
||||||
|
"codellama:7b".to_string(),
|
||||||
|
);
|
||||||
|
|
||||||
|
let prefix = "function hello() {";
|
||||||
|
let suffix = "}";
|
||||||
|
let prompt = provider.build_fim_prompt(prefix, suffix);
|
||||||
|
|
||||||
|
assert!(prompt.contains("<PRE>"));
|
||||||
|
assert!(prompt.contains("<SUF>"));
|
||||||
|
assert!(prompt.contains("<MID>"));
|
||||||
|
assert!(prompt.contains(prefix));
|
||||||
|
assert!(prompt.contains(suffix));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[gpui::test]
|
||||||
|
async fn test_fim_prompt_deepseek_pattern(_cx: &mut TestAppContext) {
|
||||||
|
let provider = OllamaCompletionProvider::new(
|
||||||
|
Arc::new(FakeHttpClient::with_404_response()),
|
||||||
|
"http://localhost:11434".to_string(),
|
||||||
|
"deepseek-coder:6.7b".to_string(),
|
||||||
|
);
|
||||||
|
|
||||||
|
let prefix = "def hello():";
|
||||||
|
let suffix = " pass";
|
||||||
|
let prompt = provider.build_fim_prompt(prefix, suffix);
|
||||||
|
|
||||||
|
assert!(prompt.contains("<|fim▁begin|>"));
|
||||||
|
assert!(prompt.contains("<|fim▁hole|>"));
|
||||||
|
assert!(prompt.contains("<|fim▁end|>"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[gpui::test]
|
||||||
|
async fn test_fim_prompt_starcoder_pattern(_cx: &mut TestAppContext) {
|
||||||
|
let provider = OllamaCompletionProvider::new(
|
||||||
|
Arc::new(FakeHttpClient::with_404_response()),
|
||||||
|
"http://localhost:11434".to_string(),
|
||||||
|
"starcoder:7b".to_string(),
|
||||||
|
);
|
||||||
|
|
||||||
|
let prefix = "def hello():";
|
||||||
|
let suffix = " pass";
|
||||||
|
let prompt = provider.build_fim_prompt(prefix, suffix);
|
||||||
|
|
||||||
|
assert!(prompt.contains("<fim_prefix>"));
|
||||||
|
assert!(prompt.contains("<fim_suffix>"));
|
||||||
|
assert!(prompt.contains("<fim_middle>"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[gpui::test]
|
||||||
|
async fn test_extract_context(cx: &mut TestAppContext) {
|
||||||
|
let provider = OllamaCompletionProvider::new(
|
||||||
|
Arc::new(FakeHttpClient::with_404_response()),
|
||||||
|
"http://localhost:11434".to_string(),
|
||||||
|
"codellama:7b".to_string(),
|
||||||
|
);
|
||||||
|
|
||||||
|
// Create a simple buffer using test context
|
||||||
|
let buffer_text = "function example() {\n let x = 1;\n let y = 2;\n // cursor here\n return x + y;\n}";
|
||||||
|
let buffer = cx.new(|cx| language::Buffer::local(buffer_text, cx));
|
||||||
|
|
||||||
|
// Position cursor at the end of the "// cursor here" line
|
||||||
|
let (prefix, suffix, _cursor_position) = cx.read(|cx| {
|
||||||
|
let buffer_snapshot = buffer.read(cx);
|
||||||
|
let cursor_position = buffer_snapshot.anchor_after(text::Point::new(3, 15)); // End of "// cursor here"
|
||||||
|
let (prefix, suffix) = provider.extract_context(&buffer_snapshot, cursor_position);
|
||||||
|
(prefix, suffix, cursor_position)
|
||||||
|
});
|
||||||
|
|
||||||
|
assert!(prefix.contains("function example()"));
|
||||||
|
assert!(prefix.contains("// cursor h"));
|
||||||
|
assert!(suffix.contains("ere"));
|
||||||
|
assert!(suffix.contains("return x + y"));
|
||||||
|
assert!(suffix.contains("}"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[gpui::test]
|
||||||
|
async fn test_suggest_with_completion(cx: &mut TestAppContext) {
|
||||||
|
let provider = cx.new(|_| {
|
||||||
|
OllamaCompletionProvider::new(
|
||||||
|
Arc::new(FakeHttpClient::with_404_response()),
|
||||||
|
"http://localhost:11434".to_string(),
|
||||||
|
"codellama:7b".to_string(),
|
||||||
|
)
|
||||||
|
});
|
||||||
|
|
||||||
|
let buffer_text = "// test";
|
||||||
|
let buffer = cx.new(|cx| language::Buffer::local(buffer_text, cx));
|
||||||
|
|
||||||
|
// Set up a mock completion
|
||||||
|
provider.update(cx, |provider, _| {
|
||||||
|
provider.current_completion = Some("console.log('hello');".to_string());
|
||||||
|
provider.buffer_id = Some(buffer.entity_id());
|
||||||
|
});
|
||||||
|
|
||||||
|
let cursor_position = cx.read(|cx| buffer.read(cx).anchor_after(text::Point::new(0, 7)));
|
||||||
|
|
||||||
|
let completion = provider.update(cx, |provider, cx| {
|
||||||
|
provider.suggest(&buffer, cursor_position, cx)
|
||||||
|
});
|
||||||
|
|
||||||
|
assert!(completion.is_some());
|
||||||
|
let completion = completion.unwrap();
|
||||||
|
assert_eq!(completion.edits.len(), 1);
|
||||||
|
assert_eq!(completion.edits[0].1, "console.log('hello');");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[gpui::test]
|
||||||
|
async fn test_suggest_empty_completion(cx: &mut TestAppContext) {
|
||||||
|
let provider = cx.new(|_| {
|
||||||
|
OllamaCompletionProvider::new(
|
||||||
|
Arc::new(FakeHttpClient::with_404_response()),
|
||||||
|
"http://localhost:11434".to_string(),
|
||||||
|
"codellama:7b".to_string(),
|
||||||
|
)
|
||||||
|
});
|
||||||
|
|
||||||
|
let buffer_text = "// test";
|
||||||
|
let buffer = cx.new(|cx| language::Buffer::local(buffer_text, cx));
|
||||||
|
|
||||||
|
// Set up an empty completion
|
||||||
|
provider.update(cx, |provider, _| {
|
||||||
|
provider.current_completion = Some(" ".to_string()); // Only whitespace
|
||||||
|
provider.buffer_id = Some(buffer.entity_id());
|
||||||
|
});
|
||||||
|
|
||||||
|
let cursor_position = cx.read(|cx| buffer.read(cx).anchor_after(text::Point::new(0, 7)));
|
||||||
|
|
||||||
|
let completion = provider.update(cx, |provider, cx| {
|
||||||
|
provider.suggest(&buffer, cursor_position, cx)
|
||||||
|
});
|
||||||
|
|
||||||
|
assert!(completion.is_none());
|
||||||
|
}
|
||||||
|
}
|
|
@ -71,6 +71,7 @@ image_viewer.workspace = true
|
||||||
indoc.workspace = true
|
indoc.workspace = true
|
||||||
inline_completion_button.workspace = true
|
inline_completion_button.workspace = true
|
||||||
inspector_ui.workspace = true
|
inspector_ui.workspace = true
|
||||||
|
ollama.workspace = true
|
||||||
install_cli.workspace = true
|
install_cli.workspace = true
|
||||||
jj_ui.workspace = true
|
jj_ui.workspace = true
|
||||||
journal.workspace = true
|
journal.workspace = true
|
||||||
|
|
|
@ -4,7 +4,9 @@ use copilot::{Copilot, CopilotCompletionProvider};
|
||||||
use editor::Editor;
|
use editor::Editor;
|
||||||
use gpui::{AnyWindowHandle, App, AppContext as _, Context, Entity, WeakEntity};
|
use gpui::{AnyWindowHandle, App, AppContext as _, Context, Entity, WeakEntity};
|
||||||
use language::language_settings::{EditPredictionProvider, all_language_settings};
|
use language::language_settings::{EditPredictionProvider, all_language_settings};
|
||||||
use settings::SettingsStore;
|
use language_models::AllLanguageModelSettings;
|
||||||
|
use ollama::OllamaCompletionProvider;
|
||||||
|
use settings::{Settings, SettingsStore};
|
||||||
use smol::stream::StreamExt;
|
use smol::stream::StreamExt;
|
||||||
use std::{cell::RefCell, rc::Rc, sync::Arc};
|
use std::{cell::RefCell, rc::Rc, sync::Arc};
|
||||||
use supermaven::{Supermaven, SupermavenCompletionProvider};
|
use supermaven::{Supermaven, SupermavenCompletionProvider};
|
||||||
|
@ -129,7 +131,8 @@ pub fn init(client: Arc<Client>, user_store: Entity<UserStore>, cx: &mut App) {
|
||||||
}
|
}
|
||||||
EditPredictionProvider::None
|
EditPredictionProvider::None
|
||||||
| EditPredictionProvider::Copilot
|
| EditPredictionProvider::Copilot
|
||||||
| EditPredictionProvider::Supermaven => {}
|
| EditPredictionProvider::Supermaven
|
||||||
|
| EditPredictionProvider::Ollama => {}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -283,5 +286,20 @@ fn assign_edit_prediction_provider(
|
||||||
editor.set_edit_prediction_provider(Some(provider), window, cx);
|
editor.set_edit_prediction_provider(Some(provider), window, cx);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
EditPredictionProvider::Ollama => {
|
||||||
|
let settings = &AllLanguageModelSettings::get_global(cx).ollama;
|
||||||
|
let api_url = settings.api_url.clone();
|
||||||
|
|
||||||
|
// Use first available model or default
|
||||||
|
let model = settings
|
||||||
|
.available_models
|
||||||
|
.first()
|
||||||
|
.map(|m| m.name.clone())
|
||||||
|
.unwrap_or_else(|| "codellama:7b".to_string());
|
||||||
|
|
||||||
|
let provider =
|
||||||
|
cx.new(|_| OllamaCompletionProvider::new(client.http_client(), api_url, model));
|
||||||
|
editor.set_edit_prediction_provider(Some(provider), window, cx);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue