Support using an API key
parent cce9949d92 · commit 9188e3f5de
3 changed files with 107 additions and 6 deletions
@@ -395,14 +395,19 @@ pub async fn show_model(client: &dyn HttpClient, api_url: &str, model: &str) ->
 pub async fn generate(
     client: &dyn HttpClient,
     api_url: &str,
+    api_key: Option<String>,
     request: GenerateRequest,
 ) -> Result<GenerateResponse> {
     let uri = format!("{api_url}/api/generate");
-    let request_builder = HttpRequest::builder()
+    let mut request_builder = HttpRequest::builder()
         .method(Method::POST)
         .uri(uri)
         .header("Content-Type", "application/json");
 
+    if let Some(api_key) = api_key {
+        request_builder = request_builder.header("Authorization", format!("Bearer {api_key}"))
+    }
+
     let serialized_request = serde_json::to_string(&request)?;
     let request = request_builder.body(AsyncBody::from(serialized_request))?;
 
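
The hunk above makes the Authorization header strictly opt-in: the builder becomes mutable, and a Bearer header is attached only when `api_key` is `Some`. Below is a minimal, self-contained sketch of that pattern using the plain `http` crate rather than Zed's `http_client` wrappers; the function name `build_generate_request` and the choice of the `http` crate are illustrative assumptions, not part of the commit.

// Sketch only: the conditional Authorization header, rebuilt on the plain `http`
// crate. `build_generate_request` is an illustrative name, not from the commit.
use http::{Method, Request};

fn build_generate_request(
    api_url: &str,
    api_key: Option<&str>,
    body: String,
) -> http::Result<Request<String>> {
    let mut builder = Request::builder()
        .method(Method::POST)
        .uri(format!("{api_url}/api/generate"))
        .header("Content-Type", "application/json");

    // Only attach the header when a key was configured; otherwise the request
    // goes out exactly as it did before this commit.
    if let Some(api_key) = api_key {
        builder = builder.header("Authorization", format!("Bearer {api_key}"));
    }

    builder.body(body)
}

fn main() -> http::Result<()> {
    let with_key = build_generate_request("http://localhost:11434", Some("secret"), "{}".into())?;
    assert!(with_key.headers().contains_key("Authorization"));

    let without_key = build_generate_request("http://localhost:11434", None, "{}".into())?;
    assert!(!without_key.headers().contains_key("Authorization"));
    Ok(())
}

Calling the sketch with `None` yields a request identical to the pre-commit behaviour, which is why existing call sites in this diff only need to pass an extra `None`.
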
@@ -674,4 +679,35 @@ mod tests {
         assert_eq!(message_images.len(), 1);
         assert_eq!(message_images[0].as_str().unwrap(), base64_image);
     }
+
+    #[test]
+    fn test_generate_request_with_api_key_serialization() {
+        let request = GenerateRequest {
+            model: "qwen2.5-coder:32b".to_string(),
+            prompt: "def fibonacci(n):".to_string(),
+            suffix: Some(" return result".to_string()),
+            stream: false,
+            options: Some(GenerateOptions {
+                num_predict: Some(150),
+                temperature: Some(0.1),
+                top_p: Some(0.95),
+                stop: None,
+            }),
+            keep_alive: None,
+            context: None,
+        };
+
+        // Test with API key
+        let json = serde_json::to_string(&request).unwrap();
+        let parsed: GenerateRequest = serde_json::from_str(&json).unwrap();
+
+        assert_eq!(parsed.model, "qwen2.5-coder:32b");
+        assert_eq!(parsed.prompt, "def fibonacci(n):");
+        assert_eq!(parsed.suffix, Some(" return result".to_string()));
+        assert!(!parsed.stream);
+        assert!(parsed.options.is_some());
+
+        // Note: The API key parameter is passed to the generate function itself,
+        // not included in the GenerateRequest struct that gets serialized to JSON
+    }
 }
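
The new test's trailing comment documents that the key is a parameter of `generate()` and never part of the serialized `GenerateRequest`. A stand-alone sketch of that property, using a stand-in struct rather than the crate's real type (serde with the derive feature and serde_json are assumed):

// Stand-in sketch, not the crate's real `GenerateRequest`: the API key travels
// as a separate argument, so serializing the body can never leak it.
use serde::Serialize;

#[derive(Serialize)]
struct RequestSketch {
    model: String,
    prompt: String,
    stream: bool,
}

fn main() {
    let body = serde_json::to_string(&RequestSketch {
        model: "qwen2.5-coder:32b".to_string(),
        prompt: "def fibonacci(n):".to_string(),
        stream: false,
    })
    .unwrap();

    // No credential field exists on the struct, so none can appear in the JSON.
    assert!(!body.contains("api_key"));
    println!("{body}");
}
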
@@ -1,9 +1,11 @@
 use crate::{GenerateOptions, GenerateRequest, generate};
 use anyhow::{Context as AnyhowContext, Result};
+
 use gpui::{App, Context, Entity, EntityId, Task};
 use http_client::HttpClient;
 use inline_completion::{Direction, EditPredictionProvider, InlineCompletion};
 use language::{Anchor, Buffer, ToOffset};
+
 use project::Project;
 use std::{path::Path, sync::Arc, time::Duration};
 
@@ -17,10 +19,16 @@ pub struct OllamaCompletionProvider {
     file_extension: Option<String>,
     current_completion: Option<String>,
     pending_refresh: Option<Task<Result<()>>>,
+    api_key: Option<String>,
 }
 
 impl OllamaCompletionProvider {
-    pub fn new(http_client: Arc<dyn HttpClient>, api_url: String, model: String) -> Self {
+    pub fn new(
+        http_client: Arc<dyn HttpClient>,
+        api_url: String,
+        model: String,
+        api_key: Option<String>,
+    ) -> Self {
         Self {
             http_client,
             api_url,
@@ -29,6 +37,7 @@ impl OllamaCompletionProvider {
             file_extension: None,
             current_completion: None,
             pending_refresh: None,
+            api_key,
         }
     }
 
@@ -114,7 +123,8 @@ impl EditPredictionProvider for OllamaCompletionProvider {
                 this.extract_context(buffer_snapshot, cursor_position)
             })?;
 
-            let model = this.update(cx, |this, _| this.model.clone())?;
+            let (model, api_key) =
+                this.update(cx, |this, _| (this.model.clone(), this.api_key.clone()))?;
 
             let request = GenerateRequest {
                 model,
@@ -131,7 +141,7 @@ impl EditPredictionProvider for OllamaCompletionProvider {
                 context: None,
             };
 
-            let response = generate(http_client.as_ref(), &api_url, request)
+            let response = generate(http_client.as_ref(), &api_url, api_key, request)
                 .await
                 .context("Failed to get completion from Ollama")?;
 
@@ -216,6 +226,7 @@ mod tests {
             Arc::new(FakeHttpClient::with_404_response()),
             "http://localhost:11434".to_string(),
             "codellama:7b".to_string(),
+            None,
         );
 
         // Create a simple buffer using test context
@@ -244,6 +255,7 @@
             Arc::new(FakeHttpClient::with_404_response()),
             "http://localhost:11434".to_string(),
             "codellama:7b".to_string(),
+            None,
         )
     });
 
@@ -275,6 +287,7 @@
             Arc::new(FakeHttpClient::with_404_response()),
             "http://localhost:11434".to_string(),
             "codellama:7b".to_string(),
+            None,
        )
     });
 
@@ -302,6 +315,7 @@
             Arc::new(FakeHttpClient::with_404_response()),
             "http://localhost:11434".to_string(),
             "codellama:7b".to_string(),
+            None,
         );
 
         // Verify initial model
@@ -332,6 +346,7 @@
             Arc::new(FakeHttpClient::with_404_response()),
             "http://localhost:11434".to_string(),
             "qwen2.5-coder:32b".to_string(),
+            None,
         );
 
         let prefix = "def fibonacci(n):";
@@ -362,4 +377,50 @@ mod tests {
         // Verify stop tokens are handled natively by Ollama
         assert!(request.options.as_ref().unwrap().stop.is_none());
     }
+
+    #[gpui::test]
+    async fn test_api_key_support(_cx: &mut TestAppContext) {
+        // Test with API key
+        let provider_with_key = OllamaCompletionProvider::new(
+            Arc::new(FakeHttpClient::with_404_response()),
+            "http://localhost:11434".to_string(),
+            "qwen2.5-coder:32b".to_string(),
+            Some("test-api-key".to_string()),
+        );
+
+        // Test without API key
+        let provider_without_key = OllamaCompletionProvider::new(
+            Arc::new(FakeHttpClient::with_404_response()),
+            "http://localhost:11434".to_string(),
+            "qwen2.5-coder:32b".to_string(),
+            None,
+        );
+
+        // Verify API key is stored correctly
+        assert_eq!(provider_with_key.api_key, Some("test-api-key".to_string()));
+        assert_eq!(provider_without_key.api_key, None);
+
+        // Verify API key is passed to generate request
+        let prefix = "def test():";
+        let suffix = " pass";
+
+        let request_with_key = GenerateRequest {
+            model: provider_with_key.model.clone(),
+            prompt: prefix.to_string(),
+            suffix: Some(suffix.to_string()),
+            stream: false,
+            options: Some(GenerateOptions {
+                num_predict: Some(150),
+                temperature: Some(0.1),
+                top_p: Some(0.95),
+                stop: None,
+            }),
+            keep_alive: None,
+            context: None,
+        };
+
+        // The actual API key usage would be tested in the generate function
+        // but we can verify the provider stores it correctly
+        assert_eq!(provider_with_key.api_key, Some("test-api-key".to_string()));
+    }
 }
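
The provider changes above store the optional key on the struct, accept it through the widened constructor, and clone it alongside the model when the refresh task builds its request. A compilable stand-in of that storage-and-threading shape (not Zed's real types; `fake_generate` is an illustrative placeholder for the async `generate()` call):

// Stand-in sketch of how the key flows from the provider into the request path.
struct ProviderSketch {
    api_url: String,
    model: String,
    api_key: Option<String>,
}

impl ProviderSketch {
    fn new(api_url: String, model: String, api_key: Option<String>) -> Self {
        Self { api_url, model, api_key }
    }

    fn refresh(&self) -> String {
        // Mirrors the diff: clone the fields the task needs, then pass the
        // optional key alongside the request.
        let (model, api_key) = (self.model.clone(), self.api_key.clone());
        fake_generate(&self.api_url, api_key, &model)
    }
}

fn fake_generate(api_url: &str, api_key: Option<String>, model: &str) -> String {
    match api_key {
        Some(_) => format!("POST {api_url}/api/generate ({model}) with Authorization header"),
        None => format!("POST {api_url}/api/generate ({model}) without Authorization header"),
    }
}

fn main() {
    let provider = ProviderSketch::new(
        "http://localhost:11434".to_string(),
        "qwen2.5-coder:32b".to_string(),
        Some("test-api-key".to_string()),
    );
    println!("{}", provider.refresh());
}
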
@@ -342,8 +342,12 @@ fn assign_edit_prediction_provider(
                 .map(|m| m.name.clone())
                 .unwrap_or_else(|| "codellama:7b".to_string());
 
-            let provider =
-                cx.new(|_| OllamaCompletionProvider::new(client.http_client(), api_url, model));
+            // Get API key from environment variable only (credentials would require async handling)
+            let api_key = std::env::var("OLLAMA_API_KEY").ok();
+
+            let provider = cx.new(|_| {
+                OllamaCompletionProvider::new(client.http_client(), api_url, model, api_key)
+            });
             editor.set_edit_prediction_provider(Some(provider), window, cx);
         }
     }
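
As the in-diff comment notes, the key is read from the OLLAMA_API_KEY environment variable only, since stored credentials would require async handling. A minimal, runnable sketch of that lookup:

// Sketch of the environment-variable lookup used above: `.ok()` maps a missing
// or non-Unicode OLLAMA_API_KEY to `None` instead of returning an error.
fn main() {
    let api_key: Option<String> = std::env::var("OLLAMA_API_KEY").ok();

    match &api_key {
        Some(_) => println!("API key found; requests will carry an Authorization: Bearer header"),
        None => println!("no API key configured; requests are sent without Authorization"),
    }
}

Exporting OLLAMA_API_KEY before launching the editor is therefore the only way to enable the Bearer header with this change; when it is unset, the provider behaves exactly as before.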