Ollama improvements (#12921)

Attempts to load the model early, as soon as the user switches to a different model.

This is a follow up to #12902

Release Notes:

- N/A
This commit is contained in:
Kyle Kelley 2024-06-12 08:10:51 -07:00 committed by GitHub
parent 113546f766
commit bee3441c78
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 67 additions and 7 deletions

View file

@@ -42,18 +42,14 @@ impl From<Role> for String {
/// An Ollama model entry: the model's name plus inference/session settings.
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct Model {
// Model identifier as known to the Ollama server (sent in API requests).
pub name: String,
// NOTE(review): this line appears in a diff hunk alongside the removal of the
// `parameter_size` constructor argument — confirm against the full diff whether
// this field survives the commit.
pub parameter_size: String,
// Maximum context size; hard-coded to 2048 elsewhere because the Ollama API
// exposes no endpoint for it (see the todo in Model::new).
pub max_tokens: usize,
// Duration string forwarded as Ollama's `keep_alive` (Model::new uses "10m").
pub keep_alive: Option<String>,
}
impl Model {
pub fn new(name: &str, parameter_size: &str) -> Self {
pub fn new(name: &str) -> Self {
Self {
name: name.to_owned(),
parameter_size: parameter_size.to_owned(),
// todo: determine if there's an endpoint to find the max tokens
// I'm not seeing it in the API docs but it's on the model cards
max_tokens: 2048,
keep_alive: Some("10m".to_owned()),
}
@@ -222,3 +218,43 @@ pub async fn get_models(
))
}
}
/// Sends an empty generate request to Ollama so the server loads the model
/// into memory ahead of the first real completion request.
///
/// A transport timeout is treated as success: the server keeps loading the
/// model in the background, which is all this call is for. Any other send
/// failure, or a non-success HTTP status, is returned as an error.
pub async fn preload_model(client: &dyn HttpClient, api_url: &str, model: &str) -> Result<()> {
    // Empty prompt + keep_alive: loads the model without generating anything.
    let payload = serde_json::to_string(&serde_json::json!({
        "model": model,
        "keep_alive": "15m",
    }))?;

    let request = HttpRequest::builder()
        .method(Method::POST)
        .uri(format!("{api_url}/api/generate"))
        .header("Content-Type", "application/json")
        .body(AsyncBody::from(payload))?;

    let mut response = match client.send(request).await {
        Ok(response) => response,
        // Be ok with a timeout during preload of the model.
        Err(err) if err.is_timeout() => return Ok(()),
        Err(err) => return Err(err.into()),
    };

    if !response.status().is_success() {
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;
        return Err(anyhow!(
            "Failed to connect to Ollama API: {} {}",
            response.status(),
            body,
        ));
    }

    Ok(())
}