Ollama max_tokens settings (#17025)

- Support `available_models` for Ollama
- Clamp default max tokens (context length) to 16384.
- Add documentation for ollama context configuration.
This commit is contained in:
Peter Tripp 2024-08-30 12:52:00 +00:00 committed by GitHub
parent d401ab1efc
commit b62e63349b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 92 additions and 35 deletions

View file

@ -6,8 +6,10 @@ use ollama::{
get_models, preload_model, stream_chat_completion, ChatMessage, ChatOptions, ChatRequest,
ChatResponseDelta, OllamaToolCall,
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
use std::{sync::Arc, time::Duration};
use std::{collections::BTreeMap, sync::Arc, time::Duration};
use ui::{prelude::*, ButtonLike, Indicator};
use util::ResultExt;
@ -28,6 +30,17 @@ const PROVIDER_NAME: &str = "Ollama";
pub struct OllamaSettings {
pub api_url: String,
pub low_speed_timeout: Option<Duration>,
pub available_models: Vec<AvailableModel>,
}
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
/// The model name in the Ollama API (e.g. "llama3.1:latest")
pub name: String,
/// The model's name in Zed's UI, such as in the model selector dropdown menu in the assistant panel.
pub display_name: Option<String>,
/// The Context Length parameter to the model (aka num_ctx or n_ctx)
pub max_tokens: usize,
}
pub struct OllamaLanguageModelProvider {
@ -61,7 +74,7 @@ impl State {
// indicating which models are embedding models,
// simply filter out models with "-embed" in their name
.filter(|model| !model.name.contains("-embed"))
.map(|model| ollama::Model::new(&model.name))
.map(|model| ollama::Model::new(&model.name, None, None))
.collect();
models.sort_by(|a, b| a.name.cmp(&b.name));
@ -123,10 +136,32 @@ impl LanguageModelProvider for OllamaLanguageModelProvider {
}
fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
self.state
.read(cx)
let mut models: BTreeMap<String, ollama::Model> = BTreeMap::default();
// Add models from the Ollama API
for model in self.state.read(cx).available_models.iter() {
models.insert(model.name.clone(), model.clone());
}
// Override with available models from settings
for model in AllLanguageModelSettings::get_global(cx)
.ollama
.available_models
.iter()
{
models.insert(
model.name.clone(),
ollama::Model {
name: model.name.clone(),
display_name: model.display_name.clone(),
max_tokens: model.max_tokens,
keep_alive: None,
},
);
}
models
.into_values()
.map(|model| {
Arc::new(OllamaLanguageModel {
id: LanguageModelId::from(model.name.clone()),

View file

@ -152,6 +152,7 @@ pub struct AnthropicSettingsContentV1 {
pub struct OllamaSettingsContent {
pub api_url: Option<String>,
pub low_speed_timeout_in_seconds: Option<u64>,
pub available_models: Option<Vec<provider::ollama::AvailableModel>>,
}
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, JsonSchema)]
@ -276,6 +277,9 @@ impl settings::Settings for AllLanguageModelSettings {
anthropic.as_ref().and_then(|s| s.available_models.clone()),
);
// Ollama
let ollama = value.ollama.clone();
merge(
&mut settings.ollama.api_url,
value.ollama.as_ref().and_then(|s| s.api_url.clone()),
@ -288,6 +292,10 @@ impl settings::Settings for AllLanguageModelSettings {
settings.ollama.low_speed_timeout =
Some(Duration::from_secs(low_speed_timeout_in_seconds));
}
merge(
&mut settings.ollama.available_models,
ollama.as_ref().and_then(|s| s.available_models.clone()),
);
// OpenAI
let (openai, upgraded) = match value.openai.clone().map(|s| s.upgrade()) {