Ollama max_tokens settings (#17025)

- Support `available_models` for Ollama
- Clamp default max tokens (context length) to 16384.
- Add documentation for ollama context configuration.
Peter Tripp 2024-08-30 12:52:00 +00:00 committed by GitHub
parent d401ab1efc
commit b62e63349b
5 changed files with 92 additions and 35 deletions
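
For a quick sense of the user-facing result: with this change, an Ollama model's context window can be raised (or lowered) from `settings.json` instead of relying on the clamped default. A minimal sketch — the model name, display name, and token count below are illustrative; the committed docs change further down shows the full example:

```json
{
  "language_models": {
    "ollama": {
      "available_models": [
        {
          "name": "mistral:latest",
          "display_name": "Mistral 7B",
          "max_tokens": 32768
        }
      ]
    }
  }
}
```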


@@ -135,6 +135,7 @@ impl AssistantSettingsContent {
             Some(language_model::settings::OllamaSettingsContent {
                 api_url,
                 low_speed_timeout_in_seconds,
+                available_models: None,
             });
         }
     },
@@ -295,7 +296,7 @@
         _ => (None, None),
     };
     settings.provider = Some(AssistantProviderContentV1::Ollama {
-        default_model: Some(ollama::Model::new(&model)),
+        default_model: Some(ollama::Model::new(&model, None, None)),
         api_url,
         low_speed_timeout_in_seconds,
     });


@@ -6,8 +6,10 @@ use ollama::{
     get_models, preload_model, stream_chat_completion, ChatMessage, ChatOptions, ChatRequest,
     ChatResponseDelta, OllamaToolCall,
 };
+use schemars::JsonSchema;
+use serde::{Deserialize, Serialize};
 use settings::{Settings, SettingsStore};
-use std::{sync::Arc, time::Duration};
+use std::{collections::BTreeMap, sync::Arc, time::Duration};
 use ui::{prelude::*, ButtonLike, Indicator};
 use util::ResultExt;
@@ -28,6 +30,17 @@ const PROVIDER_NAME: &str = "Ollama";
 pub struct OllamaSettings {
     pub api_url: String,
     pub low_speed_timeout: Option<Duration>,
+    pub available_models: Vec<AvailableModel>,
+}
+
+#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
+pub struct AvailableModel {
+    /// The model name in the Ollama API (e.g. "llama3.1:latest")
+    pub name: String,
+    /// The model's name in Zed's UI, such as in the model selector dropdown menu in the assistant panel.
+    pub display_name: Option<String>,
+    /// The Context Length parameter to the model (aka num_ctx or n_ctx)
+    pub max_tokens: usize,
 }
 
 pub struct OllamaLanguageModelProvider {
@@ -61,7 +74,7 @@ impl State {
                     // indicating which models are embedding models,
                     // simply filter out models with "-embed" in their name
                     .filter(|model| !model.name.contains("-embed"))
-                    .map(|model| ollama::Model::new(&model.name))
+                    .map(|model| ollama::Model::new(&model.name, None, None))
                     .collect();
 
                 models.sort_by(|a, b| a.name.cmp(&b.name));
@@ -123,10 +136,32 @@ impl LanguageModelProvider for OllamaLanguageModelProvider {
     }
 
     fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
-        self.state
-            .read(cx)
+        let mut models: BTreeMap<String, ollama::Model> = BTreeMap::default();
+
+        // Add models from the Ollama API
+        for model in self.state.read(cx).available_models.iter() {
+            models.insert(model.name.clone(), model.clone());
+        }
+
+        // Override with available models from settings
+        for model in AllLanguageModelSettings::get_global(cx)
+            .ollama
             .available_models
             .iter()
+        {
+            models.insert(
+                model.name.clone(),
+                ollama::Model {
+                    name: model.name.clone(),
+                    display_name: model.display_name.clone(),
+                    max_tokens: model.max_tokens,
+                    keep_alive: None,
+                },
+            );
+        }
+
+        models
+            .into_values()
             .map(|model| {
                 Arc::new(OllamaLanguageModel {
                     id: LanguageModelId::from(model.name.clone()),
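
The `BTreeMap` keyed by model name is what gives settings-defined models precedence over API-discovered ones while keeping the dropdown sorted and deduplicated. A minimal standalone sketch of that rule, using a plain `usize` as a stand-in for `ollama::Model`:

```rust
use std::collections::BTreeMap;

fn main() {
    // Key = model name; the value stands in for ollama::Model.
    let mut models: BTreeMap<String, usize> = BTreeMap::default();
    models.insert("mistral:latest".to_string(), 16384); // discovered via the Ollama API
    models.insert("mistral:latest".to_string(), 32768); // same name from settings replaces the API entry
    models.insert("llama3.1:latest".to_string(), 16384);

    assert_eq!(models["mistral:latest"], 32768); // settings value wins
    assert_eq!(models.len(), 2); // no duplicate dropdown entries; iteration stays sorted by name
}
```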


@@ -152,6 +152,7 @@ pub struct AnthropicSettingsContentV1 {
 pub struct OllamaSettingsContent {
     pub api_url: Option<String>,
     pub low_speed_timeout_in_seconds: Option<u64>,
+    pub available_models: Option<Vec<provider::ollama::AvailableModel>>,
 }
 
 #[derive(Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
@@ -276,6 +277,9 @@ impl settings::Settings for AllLanguageModelSettings {
                anthropic.as_ref().and_then(|s| s.available_models.clone()),
            );
 
+            // Ollama
+            let ollama = value.ollama.clone();
+
            merge(
                &mut settings.ollama.api_url,
                value.ollama.as_ref().and_then(|s| s.api_url.clone()),
@@ -288,6 +292,10 @@
                settings.ollama.low_speed_timeout =
                    Some(Duration::from_secs(low_speed_timeout_in_seconds));
            }
+            merge(
+                &mut settings.ollama.available_models,
+                ollama.as_ref().and_then(|s| s.available_models.clone()),
+            );
 
            // OpenAI
            let (openai, upgraded) = match value.openai.clone().map(|s| s.upgrade()) {


@@ -66,40 +66,37 @@ impl Default for KeepAlive {
 #[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
 pub struct Model {
     pub name: String,
+    pub display_name: Option<String>,
     pub max_tokens: usize,
     pub keep_alive: Option<KeepAlive>,
 }
 
+// This could be dynamically retrieved via the API (1 call per model)
+// curl -s http://localhost:11434/api/show -d '{"model": "llama3.1:latest"}' | jq '.model_info."llama.context_length"'
 fn get_max_tokens(name: &str) -> usize {
-    match name {
-        "dolphin-llama3:8b-256k" => 262144, // 256K
-        _ => match name.split(':').next().unwrap() {
-            "mistral-nemo" => 1024000, // 1M
-            "deepseek-coder-v2" => 163840, // 160K
-            "llama3.1" | "phi3" | "command-r" | "command-r-plus" => 131072, // 128K
-            "codeqwen" => 65536, // 64K
-            "mistral" | "mistral-large" | "dolphin-mistral" | "codestral" // 32K
-            | "mistral-openorca" | "dolphin-mixtral" | "mixstral" | "llava"
-            | "qwen" | "qwen2" | "wizardlm2" | "wizard-math" => 32768,
-            "codellama" | "stable-code" | "deepseek-coder" | "starcoder2" // 16K
-            | "wizardcoder" => 16384,
-            "llama3" | "gemma2" | "gemma" | "codegemma" | "dolphin-llama3" // 8K
-            | "llava-llama3" | "starcoder" | "openchat" | "aya" => 8192,
-            "llama2" | "yi" | "llama2-chinese" | "vicuna" | "nous-hermes2" // 4K
-            | "stablelm2" => 4096,
-            "phi" | "orca-mini" | "tinyllama" | "granite-code" => 2048, // 2K
-            _ => 2048, // 2K (default)
-        },
-    }
+    /// Default context length for unknown models.
+    const DEFAULT_TOKENS: usize = 2048;
+    /// Magic number. Lets many Ollama models work with ~16GB of ram.
+    const MAXIMUM_TOKENS: usize = 16384;
+
+    match name.split(':').next().unwrap() {
+        "phi" | "tinyllama" | "granite-code" => 2048,
+        "llama2" | "yi" | "vicuna" | "stablelm2" => 4096,
+        "llama3" | "gemma2" | "gemma" | "codegemma" | "starcoder" | "aya" => 8192,
+        "codellama" | "starcoder2" => 16384,
+        "mistral" | "codestral" | "mixstral" | "llava" | "qwen2" | "dolphin-mixtral" => 32768,
+        "llama3.1" | "phi3" | "phi3.5" | "command-r" | "deepseek-coder-v2" => 128000,
+        _ => DEFAULT_TOKENS,
+    }
+    .clamp(1, MAXIMUM_TOKENS)
 }
 
 impl Model {
-    pub fn new(name: &str) -> Self {
+    pub fn new(name: &str, display_name: Option<&str>, max_tokens: Option<usize>) -> Self {
         Self {
             name: name.to_owned(),
-            max_tokens: get_max_tokens(name),
+            display_name: display_name
+                .map(ToString::to_string)
+                .or_else(|| name.strip_suffix(":latest").map(ToString::to_string)),
+            max_tokens: max_tokens.unwrap_or_else(|| get_max_tokens(name)),
             keep_alive: Some(KeepAlive::indefinite()),
         }
     }
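
A hedged sketch of the new constructor's behavior (the usage below is illustrative, not part of the diff): with no overrides, a trailing `:latest` is stripped for display and the context length comes from `get_max_tokens`, clamped to 16384; values supplied through settings are taken as-is.

```rust
// Assumes the workspace `ollama` crate is available as a dependency.
use ollama::Model;

fn main() {
    // Defaults: mistral's 32768 is clamped down to MAXIMUM_TOKENS (16384).
    let default = Model::new("mistral:latest", None, None);
    assert_eq!(default.display_name(), "mistral");
    assert_eq!(default.max_token_count(), 16384);

    // Overrides from settings skip both the clamp and the name heuristic.
    let custom = Model::new("mistral:latest", Some("Mistral 7B"), Some(32768));
    assert_eq!(custom.display_name(), "Mistral 7B");
    assert_eq!(custom.max_token_count(), 32768);
}
```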
@@ -109,7 +106,7 @@ impl Model {
     }
 
     pub fn display_name(&self) -> &str {
-        &self.name
+        self.display_name.as_ref().unwrap_or(&self.name)
     }
 
     pub fn max_token_count(&self) -> usize {


@@ -108,33 +108,49 @@ Custom models will be listed in the model dropdown in the assistant panel.
 Download and install Ollama from [ollama.com/download](https://ollama.com/download) (Linux or macOS) and ensure it's running with `ollama --version`.
 
-You can use Ollama with the Zed assistant by making Ollama appear as an OpenAPI endpoint.
-
-1. Download, for example, the `mistral` model with Ollama:
+1. Download one of the [available models](https://ollama.com/models), for example, for `mistral`:
 
 ```sh
 ollama pull mistral
 ```
 
-2. Make sure that the Ollama server is running. You can start it either via running the Ollama app, or launching:
+2. Make sure that the Ollama server is running. You can start it either via running Ollama.app (macOS) or launching:
 
 ```sh
 ollama serve
 ```
 
 3. In the assistant panel, select one of the Ollama models using the model dropdown.
-4. (Optional) If you want to change the default URL that is used to access the Ollama server, you can do so by adding the following settings:
+4. (Optional) Specify a [custom api_url](#custom-endpoint) or [custom `low_speed_timeout_in_seconds`](#provider-timeout) if required.
+
+#### Ollama Context Length {#ollama-context}
+
+Zed has pre-configured maximum context lengths (`max_tokens`) to match the capabilities of common models. Zed API requests to Ollama include this as the `num_ctx` parameter, but the default values do not exceed `16384`, so users with ~16GB of RAM are able to use most models out of the box. See [get_max_tokens in ollama.rs](https://github.com/zed-industries/zed/blob/main/crates/ollama/src/ollama.rs) for a complete set of defaults.
+
+**Note**: Token counts displayed in the assistant panel are only estimates and will differ from the model's native tokenizer.
+
+Depending on your hardware or use case, you may wish to limit or increase the context length for a specific model via settings.json:
 
 ```json
 {
   "language_models": {
     "ollama": {
-      "api_url": "http://localhost:11434"
+      "low_speed_timeout_in_seconds": 120,
+      "available_models": [
+        {
+          "provider": "ollama",
+          "name": "mistral:latest",
+          "max_tokens": 32768
+        }
+      ]
     }
   }
 }
 ```
+
+If you specify a context length that is too large for your hardware, Ollama will log an error. You can watch these logs by running: `tail -f ~/.ollama/logs/ollama.log` (macOS) or `journalctl -u ollama -f` (Linux). Depending on the memory available on your machine, you may need to adjust the context length to a smaller value.
 
 ### OpenAI {#openai}
 
 1. Visit the OpenAI platform and [create an API key](https://platform.openai.com/account/api-keys)
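
To check what a model natively supports before raising `max_tokens`, the comment added in `ollama.rs` points at Ollama's show endpoint. For example (the `model_info` key is family-specific, so the `llama.` prefix is an assumption that holds for llama-family models):

```sh
# Ask Ollama for a model's native context length; adjust the key for other model families.
curl -s http://localhost:11434/api/show -d '{"model": "llama3.1:latest"}' \
  | jq '.model_info."llama.context_length"'
```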