Support using an API key

This commit is contained in:
Oliver Azevedo Barnes 2025-07-08 22:09:50 +01:00
parent cce9949d92
commit 9188e3f5de
No known key found for this signature in database
3 changed files with 107 additions and 6 deletions

View file

@ -1,9 +1,11 @@
use crate::{GenerateOptions, GenerateRequest, generate};
use anyhow::{Context as AnyhowContext, Result};
use gpui::{App, Context, Entity, EntityId, Task};
use http_client::HttpClient;
use inline_completion::{Direction, EditPredictionProvider, InlineCompletion};
use language::{Anchor, Buffer, ToOffset};
use project::Project;
use std::{path::Path, sync::Arc, time::Duration};
@ -17,10 +19,16 @@ pub struct OllamaCompletionProvider {
file_extension: Option<String>,
current_completion: Option<String>,
pending_refresh: Option<Task<Result<()>>>,
api_key: Option<String>,
}
impl OllamaCompletionProvider {
pub fn new(http_client: Arc<dyn HttpClient>, api_url: String, model: String) -> Self {
pub fn new(
http_client: Arc<dyn HttpClient>,
api_url: String,
model: String,
api_key: Option<String>,
) -> Self {
Self {
http_client,
api_url,
@ -29,6 +37,7 @@ impl OllamaCompletionProvider {
file_extension: None,
current_completion: None,
pending_refresh: None,
api_key,
}
}
@ -114,7 +123,8 @@ impl EditPredictionProvider for OllamaCompletionProvider {
this.extract_context(buffer_snapshot, cursor_position)
})?;
let model = this.update(cx, |this, _| this.model.clone())?;
let (model, api_key) =
this.update(cx, |this, _| (this.model.clone(), this.api_key.clone()))?;
let request = GenerateRequest {
model,
@ -131,7 +141,7 @@ impl EditPredictionProvider for OllamaCompletionProvider {
context: None,
};
let response = generate(http_client.as_ref(), &api_url, request)
let response = generate(http_client.as_ref(), &api_url, api_key, request)
.await
.context("Failed to get completion from Ollama")?;
@ -216,6 +226,7 @@ mod tests {
Arc::new(FakeHttpClient::with_404_response()),
"http://localhost:11434".to_string(),
"codellama:7b".to_string(),
None,
);
// Create a simple buffer using test context
@ -244,6 +255,7 @@ mod tests {
Arc::new(FakeHttpClient::with_404_response()),
"http://localhost:11434".to_string(),
"codellama:7b".to_string(),
None,
)
});
@ -275,6 +287,7 @@ mod tests {
Arc::new(FakeHttpClient::with_404_response()),
"http://localhost:11434".to_string(),
"codellama:7b".to_string(),
None,
)
});
@ -302,6 +315,7 @@ mod tests {
Arc::new(FakeHttpClient::with_404_response()),
"http://localhost:11434".to_string(),
"codellama:7b".to_string(),
None,
);
// Verify initial model
@ -332,6 +346,7 @@ mod tests {
Arc::new(FakeHttpClient::with_404_response()),
"http://localhost:11434".to_string(),
"qwen2.5-coder:32b".to_string(),
None,
);
let prefix = "def fibonacci(n):";
@ -362,4 +377,50 @@ mod tests {
// Verify stop tokens are handled natively by Ollama
assert!(request.options.as_ref().unwrap().stop.is_none());
}
#[gpui::test]
async fn test_api_key_support(_cx: &mut TestAppContext) {
// Test with API key
let provider_with_key = OllamaCompletionProvider::new(
Arc::new(FakeHttpClient::with_404_response()),
"http://localhost:11434".to_string(),
"qwen2.5-coder:32b".to_string(),
Some("test-api-key".to_string()),
);
// Test without API key
let provider_without_key = OllamaCompletionProvider::new(
Arc::new(FakeHttpClient::with_404_response()),
"http://localhost:11434".to_string(),
"qwen2.5-coder:32b".to_string(),
None,
);
// Verify API key is stored correctly
assert_eq!(provider_with_key.api_key, Some("test-api-key".to_string()));
assert_eq!(provider_without_key.api_key, None);
// Verify API key is passed to generate request
let prefix = "def test():";
let suffix = " pass";
let request_with_key = GenerateRequest {
model: provider_with_key.model.clone(),
prompt: prefix.to_string(),
suffix: Some(suffix.to_string()),
stream: false,
options: Some(GenerateOptions {
num_predict: Some(150),
temperature: Some(0.1),
top_p: Some(0.95),
stop: None,
}),
keep_alive: None,
context: None,
};
// The actual API key usage would be tested in the generate function
// but we can verify the provider stores it correctly
assert_eq!(provider_with_key.api_key, Some("test-api-key".to_string()));
}
}