crates/language_model/src/provider/ollama.rs
Thorsten Ball aee01f2c50
assistant: Remove low_speed_timeout (#20681)
This removes the `low_speed_timeout` setting from all providers as a
response to issue #19509.

The reason is that the original `low_speed_timeout` was only added as part of
#9913 because users wanted to _get rid of timeouts_: they wanted to bump the
default timeout from 5 seconds to a lot more.

Then, in #19055, the meaning of `low_speed_timeout` changed: it was turned
into a normal `timeout`, which is a different thing and breaks slower LLMs
that don't produce a complete response within the configured timeout.

So we figured: let's remove the whole thing and replace it with a default
_connect_ timeout, to make sure that we can connect to a server within 10s,
but then give the server as long as it needs to complete its response.
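
A minimal sketch of the new policy, assuming `reqwest` purely for illustration
(Zed's own `HttpClient` abstraction is not shown here): cap only the time it
takes to establish a connection, and put no deadline on how long the server
may take to finish its response.

use std::time::Duration;

// Sketch only: reqwest is assumed for illustration; the point is the policy,
// not the library.
fn build_client() -> reqwest::Result<reqwest::Client> {
    reqwest::Client::builder()
        // Fail fast if we can't reach the server at all...
        .connect_timeout(Duration::from_secs(10))
        // ...but set no overall request timeout, so a slow model can take
        // as long as it needs to finish streaming its response.
        .build()
}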

Closes #19509

Release Notes:

- Removed the `low_speed_timeout` setting from LLM provider settings. It was
only ever used to _increase_ the timeout to give LLMs more time; since we
have no other use for it, the setting is gone and LLMs now get as long as
they need to respond.

---------

Co-authored-by: Antonio <antonio@zed.dev>
Co-authored-by: Peter Tripp <peter@zed.dev>
2024-11-15 07:37:31 +01:00

553 lines · 21 KiB · Rust

use anyhow::{anyhow, bail, Result};
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
use gpui::{AnyView, AppContext, AsyncAppContext, ModelContext, Subscription, Task};
use http_client::HttpClient;
use ollama::{
get_models, preload_model, stream_chat_completion, ChatMessage, ChatOptions, ChatRequest,
ChatResponseDelta, KeepAlive, OllamaToolCall,
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
use std::{collections::BTreeMap, sync::Arc};
use ui::{prelude::*, ButtonLike, Indicator};
use util::ResultExt;
use crate::LanguageModelCompletionEvent;
use crate::{
settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
};
const OLLAMA_DOWNLOAD_URL: &str = "https://ollama.com/download";
const OLLAMA_LIBRARY_URL: &str = "https://ollama.com/library";
const OLLAMA_SITE: &str = "https://ollama.com/";
const PROVIDER_ID: &str = "ollama";
const PROVIDER_NAME: &str = "Ollama";
#[derive(Default, Debug, Clone, PartialEq)]
pub struct OllamaSettings {
pub api_url: String,
pub available_models: Vec<AvailableModel>,
}
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
/// The model name in the Ollama API (e.g. "llama3.1:latest")
pub name: String,
/// The model's name in Zed's UI, such as in the model selector dropdown menu in the assistant panel.
pub display_name: Option<String>,
/// The Context Length parameter to the model (aka num_ctx or n_ctx)
pub max_tokens: usize,
/// How long the model stays loaded in memory after the last request (Ollama's keep_alive)
pub keep_alive: Option<KeepAlive>,
}
pub struct OllamaLanguageModelProvider {
http_client: Arc<dyn HttpClient>,
state: gpui::Model<State>,
}
pub struct State {
http_client: Arc<dyn HttpClient>,
available_models: Vec<ollama::Model>,
fetch_model_task: Option<Task<Result<()>>>,
_subscription: Subscription,
}
impl State {
fn is_authenticated(&self) -> bool {
!self.available_models.is_empty()
}
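// Query the Ollama server for its installed models and cache them on this
// State, notifying observers when the list changes.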
fn fetch_models(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
let settings = &AllLanguageModelSettings::get_global(cx).ollama;
let http_client = self.http_client.clone();
let api_url = settings.api_url.clone();
// As a proxy for the server being "authenticated", we'll check if it's up by fetching the models
cx.spawn(|this, mut cx| async move {
let models = get_models(http_client.as_ref(), &api_url, None).await?;
let mut models: Vec<ollama::Model> = models
.into_iter()
// Since there is no metadata from the Ollama API
// indicating which models are embedding models,
// simply filter out models with "-embed" in their name
.filter(|model| !model.name.contains("-embed"))
.map(|model| ollama::Model::new(&model.name, None, None))
.collect();
models.sort_by(|a, b| a.name.cmp(&b.name));
this.update(&mut cx, |this, cx| {
this.available_models = models;
cx.notify();
})
})
}
fn restart_fetch_models_task(&mut self, cx: &mut ModelContext<Self>) {
let task = self.fetch_models(cx);
self.fetch_model_task.replace(task);
}
fn authenticate(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
if self.is_authenticated() {
Task::ready(Ok(()))
} else {
self.fetch_models(cx)
}
}
}
impl OllamaLanguageModelProvider {
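// Build the provider and watch the settings store so the model list is
// refetched whenever the Ollama settings change.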
pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut AppContext) -> Self {
let this = Self {
http_client: http_client.clone(),
state: cx.new_model(|cx| {
let subscription = cx.observe_global::<SettingsStore>({
let mut settings = AllLanguageModelSettings::get_global(cx).ollama.clone();
move |this: &mut State, cx| {
let new_settings = &AllLanguageModelSettings::get_global(cx).ollama;
if &settings != new_settings {
settings = new_settings.clone();
this.restart_fetch_models_task(cx);
cx.notify();
}
}
});
State {
http_client,
available_models: Default::default(),
fetch_model_task: None,
_subscription: subscription,
}
}),
};
this.state
.update(cx, |state, cx| state.restart_fetch_models_task(cx));
this
}
}
impl LanguageModelProviderState for OllamaLanguageModelProvider {
type ObservableEntity = State;
fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>> {
Some(self.state.clone())
}
}
impl LanguageModelProvider for OllamaLanguageModelProvider {
fn id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(PROVIDER_ID.into())
}
fn name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into())
}
fn icon(&self) -> IconName {
IconName::AiOllama
}
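// Combine models discovered from the Ollama API with models declared in the
// settings; a settings entry with the same name overrides the discovered one.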
fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
let mut models: BTreeMap<String, ollama::Model> = BTreeMap::default();
// Add models from the Ollama API
for model in self.state.read(cx).available_models.iter() {
models.insert(model.name.clone(), model.clone());
}
// Override with available models from settings
for model in AllLanguageModelSettings::get_global(cx)
.ollama
.available_models
.iter()
{
models.insert(
model.name.clone(),
ollama::Model {
name: model.name.clone(),
display_name: model.display_name.clone(),
max_tokens: model.max_tokens,
keep_alive: model.keep_alive.clone(),
},
);
}
models
.into_values()
.map(|model| {
Arc::new(OllamaLanguageModel {
id: LanguageModelId::from(model.name.clone()),
model: model.clone(),
http_client: self.http_client.clone(),
request_limiter: RateLimiter::new(4),
}) as Arc<dyn LanguageModel>
})
.collect()
}
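// Ask the Ollama server to preload the model in a background task, detaching
// the task and logging any error.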
fn load_model(&self, model: Arc<dyn LanguageModel>, cx: &AppContext) {
let settings = &AllLanguageModelSettings::get_global(cx).ollama;
let http_client = self.http_client.clone();
let api_url = settings.api_url.clone();
let id = model.id().0.to_string();
cx.spawn(|_| async move { preload_model(http_client, &api_url, &id).await })
.detach_and_log_err(cx);
}
fn is_authenticated(&self, cx: &AppContext) -> bool {
self.state.read(cx).is_authenticated()
}
fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>> {
self.state.update(cx, |state, cx| state.authenticate(cx))
}
fn configuration_view(&self, cx: &mut WindowContext) -> AnyView {
let state = self.state.clone();
cx.new_view(|cx| ConfigurationView::new(state, cx)).into()
}
fn reset_credentials(&self, cx: &mut AppContext) -> Task<Result<()>> {
self.state.update(cx, |state, cx| state.fetch_models(cx))
}
}
pub struct OllamaLanguageModel {
id: LanguageModelId,
model: ollama::Model,
http_client: Arc<dyn HttpClient>,
request_limiter: RateLimiter,
}
impl OllamaLanguageModel {
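// Translate Zed's provider-agnostic request into an Ollama chat request,
// mapping message roles onto Ollama chat messages and forwarding the context
// length (num_ctx), stop sequences, and temperature.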
fn to_ollama_request(&self, request: LanguageModelRequest) -> ChatRequest {
ChatRequest {
model: self.model.name.clone(),
messages: request
.messages
.into_iter()
.map(|msg| match msg.role {
Role::User => ChatMessage::User {
content: msg.string_contents(),
},
Role::Assistant => ChatMessage::Assistant {
content: msg.string_contents(),
tool_calls: None,
},
Role::System => ChatMessage::System {
content: msg.string_contents(),
},
})
.collect(),
keep_alive: self.model.keep_alive.clone().unwrap_or_default(),
stream: true,
options: Some(ChatOptions {
num_ctx: Some(self.model.max_tokens),
stop: Some(request.stop),
temperature: request.temperature.or(Some(1.0)),
..Default::default()
}),
tools: vec![],
}
}
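// Issue one chat completion via `ollama::complete`, using the API URL from
// the current settings.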
fn request_completion(
&self,
request: ChatRequest,
cx: &AsyncAppContext,
) -> BoxFuture<'static, Result<ChatResponseDelta>> {
let http_client = self.http_client.clone();
let Ok(api_url) = cx.update(|cx| {
let settings = &AllLanguageModelSettings::get_global(cx).ollama;
settings.api_url.clone()
}) else {
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
};
async move { ollama::complete(http_client.as_ref(), &api_url, request).await }.boxed()
}
}
impl LanguageModel for OllamaLanguageModel {
fn id(&self) -> LanguageModelId {
self.id.clone()
}
fn name(&self) -> LanguageModelName {
LanguageModelName::from(self.model.display_name().to_string())
}
fn provider_id(&self) -> LanguageModelProviderId {
LanguageModelProviderId(PROVIDER_ID.into())
}
fn provider_name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into())
}
fn telemetry_id(&self) -> String {
format!("ollama/{}", self.model.id())
}
fn max_token_count(&self) -> usize {
self.model.max_token_count()
}
fn count_tokens(
&self,
request: LanguageModelRequest,
_cx: &AppContext,
) -> BoxFuture<'static, Result<usize>> {
// There is no endpoint for this _yet_ in Ollama
// see: https://github.com/ollama/ollama/issues/1716 and https://github.com/ollama/ollama/issues/3582
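// In the meantime, fall back to a rough estimate of about four characters per token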
let token_count = request
.messages
.iter()
.map(|msg| msg.string_contents().chars().count())
.sum::<usize>()
/ 4;
async move { Ok(token_count) }.boxed()
}
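// Stream a chat completion, surfacing the content of each response delta as
// a text event.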
fn stream_completion(
&self,
request: LanguageModelRequest,
cx: &AsyncAppContext,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
let request = self.to_ollama_request(request);
let http_client = self.http_client.clone();
let Ok(api_url) = cx.update(|cx| {
let settings = &AllLanguageModelSettings::get_global(cx).ollama;
settings.api_url.clone()
}) else {
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
};
let future = self.request_limiter.stream(async move {
let response = stream_chat_completion(http_client.as_ref(), &api_url, request).await?;
let stream = response
.filter_map(|response| async move {
match response {
Ok(delta) => {
let content = match delta.message {
ChatMessage::User { content } => content,
ChatMessage::Assistant { content, .. } => content,
ChatMessage::System { content } => content,
};
Some(Ok(content))
}
Err(error) => Some(Err(error)),
}
})
.boxed();
Ok(stream)
});
async move {
Ok(future
.await?
.map(|result| result.map(LanguageModelCompletionEvent::Text))
.boxed())
}
.boxed()
}
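// Force a tool call: attach a single function tool to the request, run one
// completion, and return the arguments of the matching tool call as a
// one-item stream. Errors out if the model didn't call the tool.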
fn use_any_tool(
&self,
request: LanguageModelRequest,
tool_name: String,
tool_description: String,
schema: serde_json::Value,
cx: &AsyncAppContext,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
use ollama::{OllamaFunctionTool, OllamaTool};
let function = OllamaFunctionTool {
name: tool_name.clone(),
description: Some(tool_description),
parameters: Some(schema),
};
let tools = vec![OllamaTool::Function { function }];
let request = self.to_ollama_request(request).with_tools(tools);
let response = self.request_completion(request, cx);
self.request_limiter
.run(async move {
let response = response.await?;
let ChatMessage::Assistant { tool_calls, .. } = response.message else {
bail!("message does not have an assistant role");
};
if let Some(tool_calls) = tool_calls.filter(|calls| !calls.is_empty()) {
for call in tool_calls {
let OllamaToolCall::Function(function) = call;
if function.name == tool_name {
return Ok(futures::stream::once(async move {
Ok(function.arguments.to_string())
})
.boxed());
}
}
} else {
bail!("assistant message does not have any tool calls");
};
bail!("tool not used")
})
.boxed()
}
}
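// Configuration UI for the Ollama provider: shows setup instructions, links
// to download Ollama and browse models, and the current connection state.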
struct ConfigurationView {
state: gpui::Model<State>,
loading_models_task: Option<Task<()>>,
}
impl ConfigurationView {
pub fn new(state: gpui::Model<State>, cx: &mut ViewContext<Self>) -> Self {
let loading_models_task = Some(cx.spawn({
let state = state.clone();
|this, mut cx| async move {
if let Some(task) = state
.update(&mut cx, |state, cx| state.authenticate(cx))
.log_err()
{
task.await.log_err();
}
this.update(&mut cx, |this, cx| {
this.loading_models_task = None;
cx.notify();
})
.log_err();
}
}));
Self {
state,
loading_models_task,
}
}
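// Retry fetching models from the Ollama server; wired to the "Connect" button.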
fn retry_connection(&self, cx: &mut WindowContext) {
self.state
.update(cx, |state, cx| state.fetch_models(cx))
.detach_and_log_err(cx);
}
}
impl Render for ConfigurationView {
fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
let is_authenticated = self.state.read(cx).is_authenticated();
let ollama_intro = "Get up and running with Llama 3.1, Mistral, Gemma 2, and other large language models with Ollama.";
let ollama_reqs =
"Ollama must be running with at least one model installed to use it in the assistant.";
let mut inline_code_bg = cx.theme().colors().editor_background;
inline_code_bg.fade_out(0.5);
if self.loading_models_task.is_some() {
div().child(Label::new("Loading models...")).into_any()
} else {
v_flex()
.size_full()
.gap_3()
.child(
v_flex()
.size_full()
.gap_2()
.p_1()
.child(Label::new(ollama_intro))
.child(Label::new(ollama_reqs))
.child(
h_flex()
.gap_0p5()
.child(Label::new("Once installed, try "))
.child(
div()
.bg(inline_code_bg)
.px_1p5()
.rounded_md()
.child(Label::new("ollama run llama3.1")),
),
),
)
.child(
h_flex()
.w_full()
.pt_2()
.justify_between()
.gap_2()
.child(
h_flex()
.w_full()
.gap_2()
.map(|this| {
if is_authenticated {
this.child(
Button::new("ollama-site", "Ollama")
.style(ButtonStyle::Subtle)
.icon(IconName::ExternalLink)
.icon_size(IconSize::XSmall)
.icon_color(Color::Muted)
.on_click(move |_, cx| cx.open_url(OLLAMA_SITE))
.into_any_element(),
)
} else {
this.child(
Button::new(
"download_ollama_button",
"Download Ollama",
)
.style(ButtonStyle::Subtle)
.icon(IconName::ExternalLink)
.icon_size(IconSize::XSmall)
.icon_color(Color::Muted)
.on_click(move |_, cx| cx.open_url(OLLAMA_DOWNLOAD_URL))
.into_any_element(),
)
}
})
.child(
Button::new("view-models", "All Models")
.style(ButtonStyle::Subtle)
.icon(IconName::ExternalLink)
.icon_size(IconSize::XSmall)
.icon_color(Color::Muted)
.on_click(move |_, cx| cx.open_url(OLLAMA_LIBRARY_URL)),
),
)
.child(if is_authenticated {
// This is only a button to ensure the spacing is correct;
// it should stay disabled.
ButtonLike::new("connected")
.disabled(true)
// Since this won't ever be clickable, we can use the arrow cursor
.cursor_style(gpui::CursorStyle::Arrow)
.child(
h_flex()
.gap_2()
.child(Indicator::dot().color(Color::Success))
.child(Label::new("Connected"))
.into_any_element(),
)
.into_any_element()
} else {
Button::new("retry_ollama_models", "Connect")
.icon_position(IconPosition::Start)
.icon(IconName::ArrowCircle)
.on_click(cx.listener(move |this, _, cx| this.retry_connection(cx)))
.into_any_element()
}),
)
.into_any()
}
}
}