Allow Anthropic custom models to override temperature (#18160)

Release Notes:

- Allow Anthropic custom models to override "temperature"

This also centralizes the defaulting of "temperature" inside each model's `into_x` call, rather than sprinkling the default around the code.
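
As a rough sketch of that shape (simplified stand-in types, not Zed's actual definitions; the body of `into_anthropic` itself is not shown in this diff), the request carries an optional temperature and each conversion resolves it against the model's default in one place:

```rust
// Simplified stand-ins for the real types; only the temperature
// plumbing is sketched here.
struct LanguageModelRequest {
    temperature: Option<f32>,
}

struct AnthropicRequest {
    model: String,
    temperature: Option<f32>,
    max_tokens: u32,
}

impl LanguageModelRequest {
    // Mirrors the call sites in the hunks below: callers pass the model's
    // default temperature, and the fallback happens in one place.
    fn into_anthropic(
        self,
        model: String,
        default_temperature: f32,
        max_output_tokens: u32,
    ) -> AnthropicRequest {
        AnthropicRequest {
            model,
            // Keep an explicitly requested temperature; otherwise fall
            // back to the model's configured default.
            temperature: Some(self.temperature.unwrap_or(default_temperature)),
            max_tokens: max_output_tokens,
        }
    }
}
```

Callers then only forward `self.model.default_temperature()`, as the hunks below show.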
Roy Williams, 2024-09-20 16:59:12 -04:00, committed by GitHub
commit 5905fbb9ac (parent 7d62fda5a3)
12 changed files with 54 additions and 17 deletions

Anthropic provider:

@@ -51,6 +51,7 @@ pub struct AvailableModel {
     /// Configuration of Anthropic's caching API.
     pub cache_configuration: Option<LanguageModelCacheConfiguration>,
     pub max_output_tokens: Option<u32>,
+    pub default_temperature: Option<f32>,
 }
 
 pub struct AnthropicLanguageModelProvider {
@@ -200,6 +201,7 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider {
                         }
                     }),
                     max_output_tokens: model.max_output_tokens,
+                    default_temperature: model.default_temperature,
                 },
             );
         }
@@ -375,8 +377,11 @@ impl LanguageModel for AnthropicModel {
         request: LanguageModelRequest,
         cx: &AsyncAppContext,
     ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
-        let request =
-            request.into_anthropic(self.model.id().into(), self.model.max_output_tokens());
+        let request = request.into_anthropic(
+            self.model.id().into(),
+            self.model.default_temperature(),
+            self.model.max_output_tokens(),
+        );
         let request = self.stream_completion(request, cx);
         let future = self.request_limiter.stream(async move {
             let response = request.await.map_err(|err| anyhow!(err))?;
@@ -405,6 +410,7 @@ impl LanguageModel for AnthropicModel {
     ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
         let mut request = request.into_anthropic(
             self.model.tool_model_id().into(),
+            self.model.default_temperature(),
             self.model.max_output_tokens(),
         );
         request.tool_choice = Some(anthropic::ToolChoice::Tool {

Cloud provider:

@@ -87,6 +87,8 @@ pub struct AvailableModel {
     pub tool_override: Option<String>,
     /// Indicates whether this custom model supports caching.
     pub cache_configuration: Option<LanguageModelCacheConfiguration>,
+    /// The default temperature to use for this model.
+    pub default_temperature: Option<f32>,
 }
 
 pub struct CloudLanguageModelProvider {
@@ -255,6 +257,7 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
                         min_total_token: config.min_total_token,
                     }
                 }),
+                default_temperature: model.default_temperature,
                 max_output_tokens: model.max_output_tokens,
             }),
             AvailableProvider::OpenAi => CloudModel::OpenAi(open_ai::Model::Custom {
@@ -516,7 +519,11 @@ impl LanguageModel for CloudLanguageModel {
         match &self.model {
             CloudModel::Anthropic(model) => {
-                let request = request.into_anthropic(model.id().into(), model.max_output_tokens());
+                let request = request.into_anthropic(
+                    model.id().into(),
+                    model.default_temperature(),
+                    model.max_output_tokens(),
+                );
                 let client = self.client.clone();
                 let llm_api_token = self.llm_api_token.clone();
                 let future = self.request_limiter.stream(async move {
@@ -642,8 +649,11 @@ impl LanguageModel for CloudLanguageModel {
         match &self.model {
             CloudModel::Anthropic(model) => {
-                let mut request =
-                    request.into_anthropic(model.tool_model_id().into(), model.max_output_tokens());
+                let mut request = request.into_anthropic(
+                    model.tool_model_id().into(),
+                    model.default_temperature(),
+                    model.max_output_tokens(),
+                );
                 request.tool_choice = Some(anthropic::ToolChoice::Tool {
                     name: tool_name.clone(),
                 });

Ollama provider:

@@ -235,7 +235,7 @@ impl OllamaLanguageModel {
             options: Some(ChatOptions {
                 num_ctx: Some(self.model.max_tokens),
                 stop: Some(request.stop),
-                temperature: Some(request.temperature),
+                temperature: request.temperature.or(Some(1.0)),
                 ..Default::default()
             }),
             tools: vec![],
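
The Ollama change above relies on `request.temperature` now being an `Option<f32>`: `Option::or` keeps an explicitly set value and substitutes the provider's 1.0 default only when none was given. A quick self-contained check of that behavior:

```rust
fn main() {
    // An explicitly set temperature wins over the fallback...
    let explicit: Option<f32> = Some(0.2);
    assert_eq!(explicit.or(Some(1.0)), Some(0.2));

    // ...while an unset temperature picks up the 1.0 default.
    let unset: Option<f32> = None;
    assert_eq!(unset.or(Some(1.0)), Some(1.0));
}
```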