chore: switch to gemani

This commit is contained in:
Sandipsinh Rathod 2024-10-13 00:58:13 -04:00
parent e35e499460
commit 459be66406
No known key found for this signature in database
18 changed files with 142 additions and 114 deletions

@ -30,7 +30,9 @@ impl HttpCacheManager {
.eviction_policy(EvictionPolicy::lru())
.max_capacity(cache_size)
.build();
Self { cache: Arc::new(cache) }
Self {
cache: Arc::new(cache),
}
}
pub async fn clear(&self) -> Result<()> {
@ -56,7 +58,10 @@ impl CacheManager for HttpCacheManager {
response: HttpResponse,
policy: CachePolicy,
) -> Result<HttpResponse> {
let data = Store { response: response.clone(), policy };
let data = Store {
response: response.clone(),
policy,
};
self.cache.insert(cache_key, data).await;
self.cache.run_pending_tasks().await;
Ok(response)

@ -1,12 +1,15 @@
use std::sync::Arc;
use bytes::Bytes;
use http_body_util::Full;
use hyper::Response;
use crate::app_ctx::AppCtx;
use crate::http::request::Request;
use crate::margdarshak::model::Question;
use bytes::Bytes;
use http_body_util::Full;
use hyper::Response;
use std::sync::Arc;
pub async fn handle_request(req: Request, app_ctx: Arc<AppCtx>) -> anyhow::Result<Response<Full<Bytes>>> {
pub async fn handle_request(
req: Request,
app_ctx: Arc<AppCtx>,
) -> anyhow::Result<Response<Full<Bytes>>> {
let body = String::from_utf8(req.body.to_vec())?;
tracing::info!("{}", body);
let body = Question::process(&app_ctx.md, &body);

@ -1,2 +1,2 @@
pub mod request;
pub mod handle_req;
pub mod request;

@ -1,6 +1,6 @@
use bytes::Bytes;
use hyper::body::Incoming;
use http_body_util::BodyExt;
use hyper::body::Incoming;
pub struct Request {
pub method: String,
@ -17,7 +17,17 @@ impl Request {
let method = parts.method.to_string();
let path = parts.uri.path().to_string();
let query = parts.uri.query().map(|q| q.to_string());
let headers = parts.headers.iter().map(|(k, v)| (k.as_str().to_string(), v.to_str().unwrap().to_string())).collect();
Ok(Self { method, path, query, headers, body })
let headers = parts
.headers
.iter()
.map(|(k, v)| (k.as_str().to_string(), v.to_str().unwrap().to_string()))
.collect();
Ok(Self {
method,
path,
query,
headers,
body,
})
}
}

20
src/io/env.rs Normal file

@ -0,0 +1,20 @@
use crate::target_rt::EnvIO;
use std::collections::HashMap;
pub struct NativeEnvIO {
map: HashMap<String, String>,
}
impl NativeEnvIO {
pub fn new() -> Self {
Self {
map: std::env::vars().collect(),
}
}
}
impl EnvIO for NativeEnvIO {
fn get_env(&self, key: &str) -> Option<String> {
self.map.get(key).cloned()
}
}

@ -1,5 +1,5 @@
use bytes::Bytes;
use crate::target_rt::FileIO;
use bytes::Bytes;
pub struct NativeFileIO {}

@ -1,9 +1,9 @@
use crate::target_rt::HttpIO;
use anyhow::anyhow;
use http_cache_reqwest::{Cache, CacheMode, HttpCache, HttpCacheOptions};
use margdarshak_http_cache::HttpCacheManager;
use reqwest::{Client, Request};
use reqwest_middleware::{ClientBuilder, ClientWithMiddleware};
use margdarshak_http_cache::HttpCacheManager;
use crate::target_rt::HttpIO;
pub struct NativeHttpIO {
client: ClientWithMiddleware,
@ -30,7 +30,11 @@ impl NativeHttpIO {
#[async_trait::async_trait]
impl HttpIO for NativeHttpIO {
async fn execute(&self, request: Request) -> anyhow::Result<reqwest::Response> {
let resp = self.client.execute(request).await.map_err(|e| anyhow!("{}", e))?;
let resp = self
.client
.execute(request)
.await
.map_err(|e| anyhow!("{}", e))?;
Ok(resp)
}
}

@ -1,2 +1,3 @@
pub mod env;
pub mod fs;
pub mod http;

@ -1,6 +1,6 @@
pub mod target_rt;
pub mod io;
pub mod scrapper;
pub mod margdarshak;
pub mod http;
pub mod app_ctx;
pub mod http;
pub mod io;
pub mod margdarshak;
pub mod scrapper;
pub mod target_rt;

@ -1,10 +1,8 @@
Based on the attached text/link below and answer questions accordingly after I write "question:".
Please do not copy paste the text in the answer.
Write the answer in your own words.
And you should focus on mentioning link for the source if possible.
Lastly, do not provide any images, just respond in text.
Lastly, do not provide any images, just respond in text in your own words.
text:
{{input}}
@ -12,4 +10,4 @@ text:
question:
{{question}}
Make sure to respond in MD format only.
Make sure to respond in markdown format.

@ -1,17 +1,14 @@
use std::net::SocketAddr;
use std::sync::Arc;
use hyper::body::Incoming;
use hyper::service::service_fn;
use hyper_util::rt::TokioIo;
use tokio::net::TcpListener;
use crate::app_ctx::AppCtx;
use crate::http::handle_req::handle_request;
use crate::http::request::Request;
use hyper::body::Incoming;
use hyper::service::service_fn;
use hyper_util::rt::TokioIo;
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::net::TcpListener;
pub async fn start_http_1(
port: u16,
app_ctx: AppCtx,
) -> anyhow::Result<()> {
pub async fn start_http_1(port: u16, app_ctx: AppCtx) -> anyhow::Result<()> {
let addr = SocketAddr::new([0, 0, 0, 0].into(), port);
let listener = TcpListener::bind(addr).await?;
let app_ctx = Arc::new(app_ctx);
@ -29,17 +26,11 @@ pub async fn start_http_1(
.serve_connection(
io,
service_fn(move |req: hyper::Request<Incoming>| {
{
let app_ctx = app_ctx.clone();
async move {
let req = Request::from_hyper(req).await?;
handle_request(
req,
app_ctx.clone(),
)
.await
}
handle_request(req, app_ctx.clone()).await
}
}),
)

@ -1,26 +1,3 @@
// use qdrant_client::prelude::PointStruct;
// use qdrant_client::qdrant::{Vector, Vectors};
// use qdrant_client::qdrant::vectors::VectorsOptions;
/*fn chunk_md_content(content: &str) -> Vec<String> {
let chunks: Vec<String> = content.split("\n\n").map(|s| s.trim().to_string()).collect();
chunks
}
*/
/*fn store_embeddings(text_chunks: Vec<String>, client: &qdrant_client::Qdrant) {
for (i, chunk) in text_chunks.iter().enumerate() {
let vector = Vector { data: embedding.to_vec(), indices: None, vectors_count: None };
let options = VectorsOptions::Vector(vector);
let point = PointStruct {
id: i as u64,
vectors: Some(Vectors { vectors_options: Some(options) }),
..Default::default()
};
client.upload_point("your_collection", point).unwrap();
}
}*/
pub mod model;
pub mod http1;
pub mod model;
pub mod runner;

@ -1,8 +1,8 @@
use anyhow::{anyhow, Result};
use genai::adapter::AdapterKind;
use genai::chat::{ChatMessage, ChatOptions, ChatRequest, ChatResponse};
use genai::Client;
use genai::resolver::AuthResolver;
use anyhow::{anyhow, Result};
use genai::Client;
const BASE: &str = include_str!("helper_md/query.md");
@ -34,7 +34,6 @@ impl TryInto<ChatRequest> for Question<'_> {
type Error = anyhow::Error;
fn try_into(self) -> Result<ChatRequest> {
Ok(ChatRequest::new(vec![
ChatMessage::system(self.md),
ChatMessage::user(self.qry),
@ -51,13 +50,16 @@ impl TryFrom<ChatResponse> for Answer {
fn try_from(response: ChatResponse) -> Result<Self> {
let message_content = response.content.ok_or(anyhow!("No response found"))?;
let text_content = message_content.text_as_str().ok_or(anyhow!("Unable to deserialize response"))?;
let text_content = message_content
.text_as_str()
.ok_or(anyhow!("Unable to deserialize response"))?;
// Ok(serde_json::from_str(text_content)?)
Ok(Self { ans: text_content.to_string() })
Ok(Self {
ans: text_content.to_string(),
})
}
}
impl Wizard {
pub fn new(model: String, secret: Option<String>) -> Self {
let mut config = genai::adapter::AdapterConfig::default();
@ -65,9 +67,7 @@ impl Wizard {
config = config.with_auth_resolver(AuthResolver::from_key_value(key));
}
let adapter_kind = AdapterKind::from_model(model.as_str())
.unwrap_or(AdapterKind::Ollama);
let adapter_kind = AdapterKind::from_model(model.as_str()).unwrap_or(AdapterKind::Ollama);
let chat_options = ChatOptions::default()
.with_json_mode(true)
@ -87,6 +87,6 @@ impl Wizard {
.exec_chat(self.model.as_str(), q.try_into()?, None)
.await?;
Answer::try_from(response).map_err(|e| anyhow!("{}",e.to_string()))
Answer::try_from(response).map_err(|e| anyhow!("{}", e.to_string()))
}
}

@ -1,34 +1,43 @@
use std::sync::Arc;
use crate::app_ctx::AppCtx;
use crate::io::env::NativeEnvIO;
use crate::io::fs::NativeFileIO;
use crate::io::http::NativeHttpIO;
use crate::margdarshak::http1;
use crate::margdarshak::model::{Question, Wizard};
use crate::target_rt::TargetRuntime;
use std::sync::Arc;
pub async fn run() -> anyhow::Result<()> {
let rt = TargetRuntime {
http: Arc::new(NativeHttpIO::new()),
fs: Arc::new(NativeFileIO {}),
env: Arc::new(NativeEnvIO::new()),
};
let md = crate::scrapper::scrape::handle("https://icds-docs.readthedocs.io/en/latest/".to_string(), &rt).await?;
let md = crate::scrapper::scrape::handle(
"https://icds-docs.readthedocs.io/en/latest/".to_string(),
&rt,
)
.await?;
// let md = "https://icds-docs.readthedocs.io/en/latest/".to_string();
// TODO: add wrapper structs to avoid accidentally
// query instead of MD and vis versa.
let query = "".to_string();
let query = Question::process(&md, &query);
let query = "This is the input, now you should refer this input and answer following questions"
.to_string();
// let query = "".to_string();
// let query = Question::process(&md, &query);
let query = Question::process(&md, &query);
let question = Question::new(&md, &query);
let wizard = Wizard::new("llama3.2".to_string(), None);
let wizard = Wizard::new(
"gemini-1.5-flash-latest".to_string(),
rt.env.get_env("API_KEY"),
);
tracing::info!("Warming up!");
let _ans = wizard.ask(question).await?;
tracing::info!("Warmup complete.");
let app_ctx = AppCtx {
wizard,
md,
};
let app_ctx = AppCtx { wizard, md };
http1::start_http_1(19194, app_ctx).await?;
Ok(())

@ -1,6 +1,6 @@
use crate::target_rt::TargetRuntime;
use reqwest::Url;
use scraper::{Html, Selector};
use crate::target_rt::TargetRuntime;
async fn handle_inner(url: &str, runtime: &TargetRuntime) -> anyhow::Result<String> {
let url = Url::parse(url)?;
@ -20,10 +20,16 @@ pub async fn handle(mut url: String, runtime: &TargetRuntime) -> anyhow::Result<
loop {
let parsed = Html::parse_document(&root);
let val =
Selector::parse("a")
.ok()
.and_then(|link_selector| parsed.select(&link_selector).find(|val| val.text().collect::<Vec<_>>().join(" ").trim().to_lowercase().eq("next")));
let val = Selector::parse("a").ok().and_then(|link_selector| {
parsed.select(&link_selector).find(|val| {
val.text()
.collect::<Vec<_>>()
.join(" ")
.trim()
.to_lowercase()
.eq("next")
})
});
match val {
Some(val) => {
@ -38,9 +44,12 @@ pub async fn handle(mut url: String, runtime: &TargetRuntime) -> anyhow::Result<
break;
}
}
};
let mds = htmls.into_iter().map(|v| htmd::convert(v.as_str())).flatten().collect::<Vec<_>>();
}
let mds = htmls
.into_iter()
.map(|v| htmd::convert(v.as_str()))
.flatten()
.collect::<Vec<_>>();
let text = mds.join("\n");
// let text: String = md_to_text::convert(&text);
Ok(text)
}

@ -1,15 +1,15 @@
use std::sync::Arc;
use bytes::Bytes;
use std::sync::Arc;
pub trait EnvIO {
fn get_env(&self, key: &str) -> Option<String>;
}
#[async_trait::async_trait]
pub trait HttpIO: Sync + Send + 'static {
async fn execute(
&self,
request: reqwest::Request,
) -> anyhow::Result<reqwest::Response>;
async fn execute(&self, request: reqwest::Request) -> anyhow::Result<reqwest::Response>;
}
#[async_trait::async_trait]
pub trait FileIO: Send + Sync {
async fn write<'a>(&'a self, path: &'a str, content: &'a Bytes) -> anyhow::Result<()>;
@ -19,4 +19,5 @@ pub trait FileIO: Send + Sync {
pub struct TargetRuntime {
pub http: Arc<dyn HttpIO>,
pub fs: Arc<dyn FileIO>,
pub env: Arc<dyn EnvIO>,
}