Ollama Provider for Assistant (#12902)
Closes #4424. A few design decisions that may need some rethinking or later PRs: * Other providers have a check for authentication. I use this opportunity to fetch the models which doubles as a way of finding out if the Ollama server is running. * Ollama has _no_ API for getting the max tokens per model * Ollama has _no_ API for getting the current token count https://github.com/ollama/ollama/issues/1716 * Ollama does allow setting the `num_ctx` so I've defaulted this to 4096. It can be overridden in settings. * Ollama models will be "slow" to start inference because they're loading the model into memory. It's faster after that. There's no UI affordance to show that the model is being loaded. Release Notes: - Added an Ollama Provider for the assistant. If you have [Ollama](https://ollama.com/) running locally on your machine, you can enable it in your settings under: ```jsonc "assistant": { "version": "1", "provider": { "name": "ollama", // Recommended setting to allow for model startup "low_speed_timeout_in_seconds": 30, } } ``` Chat like usual <img width="1840" alt="image" src="https://github.com/zed-industries/zed/assets/836375/4e0af266-4c4f-4d9e-9d74-1a91f76a12fe"> Interact with any model from the [Ollama Library](https://ollama.com/library) <img width="587" alt="image" src="https://github.com/zed-industries/zed/assets/836375/87433ac6-bf87-4a99-89e1-96a93bf8de8a"> Open up the terminal to download new models via `ollama pull`: 
This commit is contained in:
parent
127b9ed857
commit
4cb8d6f40e
9 changed files with 624 additions and 1 deletions
22
crates/ollama/Cargo.toml
Normal file
22
crates/ollama/Cargo.toml
Normal file
|
@ -0,0 +1,22 @@
|
|||
[package]
|
||||
name = "ollama"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
publish = false
|
||||
license = "GPL-3.0-or-later"
|
||||
|
||||
[lib]
|
||||
path = "src/ollama.rs"
|
||||
|
||||
[features]
|
||||
default = []
|
||||
schemars = ["dep:schemars"]
|
||||
|
||||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
futures.workspace = true
|
||||
http.workspace = true
|
||||
isahc.workspace = true
|
||||
schemars = { workspace = true, optional = true }
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
224
crates/ollama/src/ollama.rs
Normal file
224
crates/ollama/src/ollama.rs
Normal file
|
@ -0,0 +1,224 @@
|
|||
use anyhow::{anyhow, Context, Result};
|
||||
use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, StreamExt};
|
||||
use http::{AsyncBody, HttpClient, Method, Request as HttpRequest};
|
||||
use isahc::config::Configurable;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{convert::TryFrom, time::Duration};
|
||||
|
||||
pub const OLLAMA_API_URL: &str = "http://localhost:11434";
|
||||
|
||||
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum Role {
|
||||
User,
|
||||
Assistant,
|
||||
System,
|
||||
}
|
||||
|
||||
impl TryFrom<String> for Role {
|
||||
type Error = anyhow::Error;
|
||||
|
||||
fn try_from(value: String) -> Result<Self> {
|
||||
match value.as_str() {
|
||||
"user" => Ok(Self::User),
|
||||
"assistant" => Ok(Self::Assistant),
|
||||
"system" => Ok(Self::System),
|
||||
_ => Err(anyhow!("invalid role '{value}'")),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Role> for String {
|
||||
fn from(val: Role) -> Self {
|
||||
match val {
|
||||
Role::User => "user".to_owned(),
|
||||
Role::Assistant => "assistant".to_owned(),
|
||||
Role::System => "system".to_owned(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
|
||||
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
|
||||
pub struct Model {
|
||||
pub name: String,
|
||||
pub parameter_size: String,
|
||||
pub max_tokens: usize,
|
||||
pub keep_alive: Option<String>,
|
||||
}
|
||||
|
||||
impl Model {
|
||||
pub fn new(name: &str, parameter_size: &str) -> Self {
|
||||
Self {
|
||||
name: name.to_owned(),
|
||||
parameter_size: parameter_size.to_owned(),
|
||||
// todo: determine if there's an endpoint to find the max tokens
|
||||
// I'm not seeing it in the API docs but it's on the model cards
|
||||
max_tokens: 2048,
|
||||
keep_alive: Some("10m".to_owned()),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn id(&self) -> &str {
|
||||
&self.name
|
||||
}
|
||||
|
||||
pub fn display_name(&self) -> &str {
|
||||
&self.name
|
||||
}
|
||||
|
||||
pub fn max_token_count(&self) -> usize {
|
||||
self.max_tokens
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
|
||||
#[serde(tag = "role", rename_all = "lowercase")]
|
||||
pub enum ChatMessage {
|
||||
Assistant { content: String },
|
||||
User { content: String },
|
||||
System { content: String },
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct ChatRequest {
|
||||
pub model: String,
|
||||
pub messages: Vec<ChatMessage>,
|
||||
pub stream: bool,
|
||||
pub keep_alive: Option<String>,
|
||||
pub options: Option<ChatOptions>,
|
||||
}
|
||||
|
||||
// https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values
|
||||
#[derive(Serialize, Default)]
|
||||
pub struct ChatOptions {
|
||||
pub num_ctx: Option<usize>,
|
||||
pub num_predict: Option<isize>,
|
||||
pub stop: Option<Vec<String>>,
|
||||
pub temperature: Option<f32>,
|
||||
pub top_p: Option<f32>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct ChatResponseDelta {
|
||||
#[allow(unused)]
|
||||
pub model: String,
|
||||
#[allow(unused)]
|
||||
pub created_at: String,
|
||||
pub message: ChatMessage,
|
||||
#[allow(unused)]
|
||||
pub done_reason: Option<String>,
|
||||
#[allow(unused)]
|
||||
pub done: bool,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct LocalModelsResponse {
|
||||
pub models: Vec<LocalModelListing>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct LocalModelListing {
|
||||
pub name: String,
|
||||
pub modified_at: String,
|
||||
pub size: u64,
|
||||
pub digest: String,
|
||||
pub details: ModelDetails,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct LocalModel {
|
||||
pub modelfile: String,
|
||||
pub parameters: String,
|
||||
pub template: String,
|
||||
pub details: ModelDetails,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct ModelDetails {
|
||||
pub format: String,
|
||||
pub family: String,
|
||||
pub families: Option<Vec<String>>,
|
||||
pub parameter_size: String,
|
||||
pub quantization_level: String,
|
||||
}
|
||||
|
||||
pub async fn stream_chat_completion(
|
||||
client: &dyn HttpClient,
|
||||
api_url: &str,
|
||||
request: ChatRequest,
|
||||
low_speed_timeout: Option<Duration>,
|
||||
) -> Result<BoxStream<'static, Result<ChatResponseDelta>>> {
|
||||
let uri = format!("{api_url}/api/chat");
|
||||
let mut request_builder = HttpRequest::builder()
|
||||
.method(Method::POST)
|
||||
.uri(uri)
|
||||
.header("Content-Type", "application/json");
|
||||
|
||||
if let Some(low_speed_timeout) = low_speed_timeout {
|
||||
request_builder = request_builder.low_speed_timeout(100, low_speed_timeout);
|
||||
};
|
||||
|
||||
let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
|
||||
let mut response = client.send(request).await?;
|
||||
if response.status().is_success() {
|
||||
let reader = BufReader::new(response.into_body());
|
||||
|
||||
Ok(reader
|
||||
.lines()
|
||||
.filter_map(|line| async move {
|
||||
match line {
|
||||
Ok(line) => {
|
||||
Some(serde_json::from_str(&line).context("Unable to parse chat response"))
|
||||
}
|
||||
Err(e) => Some(Err(e.into())),
|
||||
}
|
||||
})
|
||||
.boxed())
|
||||
} else {
|
||||
let mut body = String::new();
|
||||
response.body_mut().read_to_string(&mut body).await?;
|
||||
|
||||
Err(anyhow!(
|
||||
"Failed to connect to Ollama API: {} {}",
|
||||
response.status(),
|
||||
body,
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn get_models(
|
||||
client: &dyn HttpClient,
|
||||
api_url: &str,
|
||||
low_speed_timeout: Option<Duration>,
|
||||
) -> Result<Vec<LocalModelListing>> {
|
||||
let uri = format!("{api_url}/api/tags");
|
||||
let mut request_builder = HttpRequest::builder()
|
||||
.method(Method::GET)
|
||||
.uri(uri)
|
||||
.header("Accept", "application/json");
|
||||
|
||||
if let Some(low_speed_timeout) = low_speed_timeout {
|
||||
request_builder = request_builder.low_speed_timeout(100, low_speed_timeout);
|
||||
};
|
||||
|
||||
let request = request_builder.body(AsyncBody::default())?;
|
||||
|
||||
let mut response = client.send(request).await?;
|
||||
|
||||
let mut body = String::new();
|
||||
response.body_mut().read_to_string(&mut body).await?;
|
||||
|
||||
if response.status().is_success() {
|
||||
let response: LocalModelsResponse =
|
||||
serde_json::from_str(&body).context("Unable to parse Ollama tag listing")?;
|
||||
|
||||
Ok(response.models)
|
||||
} else {
|
||||
Err(anyhow!(
|
||||
"Failed to connect to Ollama API: {} {}",
|
||||
response.status(),
|
||||
body,
|
||||
))
|
||||
}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue