Merge pull request 'chore: switch to gemani' (#2) from chore/switch-to-gemani into main

Reviewed-on: #2
This commit is contained in:
ssdd 2024-10-13 05:04:17 +00:00
commit 3ccff4a026
18 changed files with 142 additions and 114 deletions

@ -30,7 +30,9 @@ impl HttpCacheManager {
.eviction_policy(EvictionPolicy::lru()) .eviction_policy(EvictionPolicy::lru())
.max_capacity(cache_size) .max_capacity(cache_size)
.build(); .build();
Self { cache: Arc::new(cache) } Self {
cache: Arc::new(cache),
}
} }
pub async fn clear(&self) -> Result<()> { pub async fn clear(&self) -> Result<()> {
@ -56,7 +58,10 @@ impl CacheManager for HttpCacheManager {
response: HttpResponse, response: HttpResponse,
policy: CachePolicy, policy: CachePolicy,
) -> Result<HttpResponse> { ) -> Result<HttpResponse> {
let data = Store { response: response.clone(), policy }; let data = Store {
response: response.clone(),
policy,
};
self.cache.insert(cache_key, data).await; self.cache.insert(cache_key, data).await;
self.cache.run_pending_tasks().await; self.cache.run_pending_tasks().await;
Ok(response) Ok(response)

@ -1,12 +1,15 @@
use std::sync::Arc;
use bytes::Bytes;
use http_body_util::Full;
use hyper::Response;
use crate::app_ctx::AppCtx; use crate::app_ctx::AppCtx;
use crate::http::request::Request; use crate::http::request::Request;
use crate::margdarshak::model::Question; use crate::margdarshak::model::Question;
use bytes::Bytes;
use http_body_util::Full;
use hyper::Response;
use std::sync::Arc;
pub async fn handle_request(req: Request, app_ctx: Arc<AppCtx>) -> anyhow::Result<Response<Full<Bytes>>> { pub async fn handle_request(
req: Request,
app_ctx: Arc<AppCtx>,
) -> anyhow::Result<Response<Full<Bytes>>> {
let body = String::from_utf8(req.body.to_vec())?; let body = String::from_utf8(req.body.to_vec())?;
tracing::info!("{}", body); tracing::info!("{}", body);
let body = Question::process(&app_ctx.md, &body); let body = Question::process(&app_ctx.md, &body);

@ -1,2 +1,2 @@
pub mod request;
pub mod handle_req; pub mod handle_req;
pub mod request;

@ -1,6 +1,6 @@
use bytes::Bytes; use bytes::Bytes;
use hyper::body::Incoming;
use http_body_util::BodyExt; use http_body_util::BodyExt;
use hyper::body::Incoming;
pub struct Request { pub struct Request {
pub method: String, pub method: String,
@ -17,7 +17,17 @@ impl Request {
let method = parts.method.to_string(); let method = parts.method.to_string();
let path = parts.uri.path().to_string(); let path = parts.uri.path().to_string();
let query = parts.uri.query().map(|q| q.to_string()); let query = parts.uri.query().map(|q| q.to_string());
let headers = parts.headers.iter().map(|(k, v)| (k.as_str().to_string(), v.to_str().unwrap().to_string())).collect(); let headers = parts
Ok(Self { method, path, query, headers, body }) .headers
.iter()
.map(|(k, v)| (k.as_str().to_string(), v.to_str().unwrap().to_string()))
.collect();
Ok(Self {
method,
path,
query,
headers,
body,
})
} }
} }

20
src/io/env.rs Normal file

@ -0,0 +1,20 @@
use crate::target_rt::EnvIO;
use std::collections::HashMap;
pub struct NativeEnvIO {
map: HashMap<String, String>,
}
impl NativeEnvIO {
pub fn new() -> Self {
Self {
map: std::env::vars().collect(),
}
}
}
impl EnvIO for NativeEnvIO {
fn get_env(&self, key: &str) -> Option<String> {
self.map.get(key).cloned()
}
}

@ -1,5 +1,5 @@
use bytes::Bytes;
use crate::target_rt::FileIO; use crate::target_rt::FileIO;
use bytes::Bytes;
pub struct NativeFileIO {} pub struct NativeFileIO {}

@ -1,9 +1,9 @@
use crate::target_rt::HttpIO;
use anyhow::anyhow; use anyhow::anyhow;
use http_cache_reqwest::{Cache, CacheMode, HttpCache, HttpCacheOptions}; use http_cache_reqwest::{Cache, CacheMode, HttpCache, HttpCacheOptions};
use margdarshak_http_cache::HttpCacheManager;
use reqwest::{Client, Request}; use reqwest::{Client, Request};
use reqwest_middleware::{ClientBuilder, ClientWithMiddleware}; use reqwest_middleware::{ClientBuilder, ClientWithMiddleware};
use margdarshak_http_cache::HttpCacheManager;
use crate::target_rt::HttpIO;
pub struct NativeHttpIO { pub struct NativeHttpIO {
client: ClientWithMiddleware, client: ClientWithMiddleware,
@ -30,7 +30,11 @@ impl NativeHttpIO {
#[async_trait::async_trait] #[async_trait::async_trait]
impl HttpIO for NativeHttpIO { impl HttpIO for NativeHttpIO {
async fn execute(&self, request: Request) -> anyhow::Result<reqwest::Response> { async fn execute(&self, request: Request) -> anyhow::Result<reqwest::Response> {
let resp = self.client.execute(request).await.map_err(|e| anyhow!("{}", e))?; let resp = self
.client
.execute(request)
.await
.map_err(|e| anyhow!("{}", e))?;
Ok(resp) Ok(resp)
} }
} }

@ -1,2 +1,3 @@
pub mod env;
pub mod fs; pub mod fs;
pub mod http; pub mod http;

@ -1,6 +1,6 @@
pub mod target_rt;
pub mod io;
pub mod scrapper;
pub mod margdarshak;
pub mod http;
pub mod app_ctx; pub mod app_ctx;
pub mod http;
pub mod io;
pub mod margdarshak;
pub mod scrapper;
pub mod target_rt;

@ -1,10 +1,8 @@
Based on the attached text/link below and answer questions accordingly after I write "question:". Based on the attached text/link below and answer questions accordingly after I write "question:".
Please do not copy paste the text in the answer.
Write the answer in your own words.
And you should focus on mentioning link for the source if possible. And you should focus on mentioning link for the source if possible.
Lastly, do not provide any images, just respond in text. Lastly, do not provide any images, just respond in text in your own words.
text: text:
{{input}} {{input}}
@ -12,4 +10,4 @@ text:
question: question:
{{question}} {{question}}
Make sure to respond in MD format only. Make sure to respond in markdown format.

@ -1,17 +1,14 @@
use std::net::SocketAddr;
use std::sync::Arc;
use hyper::body::Incoming;
use hyper::service::service_fn;
use hyper_util::rt::TokioIo;
use tokio::net::TcpListener;
use crate::app_ctx::AppCtx; use crate::app_ctx::AppCtx;
use crate::http::handle_req::handle_request; use crate::http::handle_req::handle_request;
use crate::http::request::Request; use crate::http::request::Request;
use hyper::body::Incoming;
use hyper::service::service_fn;
use hyper_util::rt::TokioIo;
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::net::TcpListener;
pub async fn start_http_1( pub async fn start_http_1(port: u16, app_ctx: AppCtx) -> anyhow::Result<()> {
port: u16,
app_ctx: AppCtx,
) -> anyhow::Result<()> {
let addr = SocketAddr::new([0, 0, 0, 0].into(), port); let addr = SocketAddr::new([0, 0, 0, 0].into(), port);
let listener = TcpListener::bind(addr).await?; let listener = TcpListener::bind(addr).await?;
let app_ctx = Arc::new(app_ctx); let app_ctx = Arc::new(app_ctx);
@ -29,17 +26,11 @@ pub async fn start_http_1(
.serve_connection( .serve_connection(
io, io,
service_fn(move |req: hyper::Request<Incoming>| { service_fn(move |req: hyper::Request<Incoming>| {
{
let app_ctx = app_ctx.clone(); let app_ctx = app_ctx.clone();
async move { async move {
let req = Request::from_hyper(req).await?; let req = Request::from_hyper(req).await?;
handle_request( handle_request(req, app_ctx.clone()).await
req,
app_ctx.clone(),
)
.await
}
} }
}), }),
) )

@ -1,26 +1,3 @@
// use qdrant_client::prelude::PointStruct;
// use qdrant_client::qdrant::{Vector, Vectors};
// use qdrant_client::qdrant::vectors::VectorsOptions;
/*fn chunk_md_content(content: &str) -> Vec<String> {
let chunks: Vec<String> = content.split("\n\n").map(|s| s.trim().to_string()).collect();
chunks
}
*/
/*fn store_embeddings(text_chunks: Vec<String>, client: &qdrant_client::Qdrant) {
for (i, chunk) in text_chunks.iter().enumerate() {
let vector = Vector { data: embedding.to_vec(), indices: None, vectors_count: None };
let options = VectorsOptions::Vector(vector);
let point = PointStruct {
id: i as u64,
vectors: Some(Vectors { vectors_options: Some(options) }),
..Default::default()
};
client.upload_point("your_collection", point).unwrap();
}
}*/
pub mod model;
pub mod http1; pub mod http1;
pub mod model;
pub mod runner; pub mod runner;

@ -1,8 +1,8 @@
use anyhow::{anyhow, Result};
use genai::adapter::AdapterKind; use genai::adapter::AdapterKind;
use genai::chat::{ChatMessage, ChatOptions, ChatRequest, ChatResponse}; use genai::chat::{ChatMessage, ChatOptions, ChatRequest, ChatResponse};
use genai::Client;
use genai::resolver::AuthResolver; use genai::resolver::AuthResolver;
use anyhow::{anyhow, Result}; use genai::Client;
const BASE: &str = include_str!("helper_md/query.md"); const BASE: &str = include_str!("helper_md/query.md");
@ -34,7 +34,6 @@ impl TryInto<ChatRequest> for Question<'_> {
type Error = anyhow::Error; type Error = anyhow::Error;
fn try_into(self) -> Result<ChatRequest> { fn try_into(self) -> Result<ChatRequest> {
Ok(ChatRequest::new(vec![ Ok(ChatRequest::new(vec![
ChatMessage::system(self.md), ChatMessage::system(self.md),
ChatMessage::user(self.qry), ChatMessage::user(self.qry),
@ -51,13 +50,16 @@ impl TryFrom<ChatResponse> for Answer {
fn try_from(response: ChatResponse) -> Result<Self> { fn try_from(response: ChatResponse) -> Result<Self> {
let message_content = response.content.ok_or(anyhow!("No response found"))?; let message_content = response.content.ok_or(anyhow!("No response found"))?;
let text_content = message_content.text_as_str().ok_or(anyhow!("Unable to deserialize response"))?; let text_content = message_content
.text_as_str()
.ok_or(anyhow!("Unable to deserialize response"))?;
// Ok(serde_json::from_str(text_content)?) // Ok(serde_json::from_str(text_content)?)
Ok(Self { ans: text_content.to_string() }) Ok(Self {
ans: text_content.to_string(),
})
} }
} }
impl Wizard { impl Wizard {
pub fn new(model: String, secret: Option<String>) -> Self { pub fn new(model: String, secret: Option<String>) -> Self {
let mut config = genai::adapter::AdapterConfig::default(); let mut config = genai::adapter::AdapterConfig::default();
@ -65,9 +67,7 @@ impl Wizard {
config = config.with_auth_resolver(AuthResolver::from_key_value(key)); config = config.with_auth_resolver(AuthResolver::from_key_value(key));
} }
let adapter_kind = AdapterKind::from_model(model.as_str()).unwrap_or(AdapterKind::Ollama);
let adapter_kind = AdapterKind::from_model(model.as_str())
.unwrap_or(AdapterKind::Ollama);
let chat_options = ChatOptions::default() let chat_options = ChatOptions::default()
.with_json_mode(true) .with_json_mode(true)
@ -87,6 +87,6 @@ impl Wizard {
.exec_chat(self.model.as_str(), q.try_into()?, None) .exec_chat(self.model.as_str(), q.try_into()?, None)
.await?; .await?;
Answer::try_from(response).map_err(|e| anyhow!("{}",e.to_string())) Answer::try_from(response).map_err(|e| anyhow!("{}", e.to_string()))
} }
} }

@ -1,34 +1,43 @@
use std::sync::Arc;
use crate::app_ctx::AppCtx; use crate::app_ctx::AppCtx;
use crate::io::env::NativeEnvIO;
use crate::io::fs::NativeFileIO; use crate::io::fs::NativeFileIO;
use crate::io::http::NativeHttpIO; use crate::io::http::NativeHttpIO;
use crate::margdarshak::http1; use crate::margdarshak::http1;
use crate::margdarshak::model::{Question, Wizard}; use crate::margdarshak::model::{Question, Wizard};
use crate::target_rt::TargetRuntime; use crate::target_rt::TargetRuntime;
use std::sync::Arc;
pub async fn run() -> anyhow::Result<()> { pub async fn run() -> anyhow::Result<()> {
let rt = TargetRuntime { let rt = TargetRuntime {
http: Arc::new(NativeHttpIO::new()), http: Arc::new(NativeHttpIO::new()),
fs: Arc::new(NativeFileIO {}), fs: Arc::new(NativeFileIO {}),
env: Arc::new(NativeEnvIO::new()),
}; };
let md = crate::scrapper::scrape::handle("https://icds-docs.readthedocs.io/en/latest/".to_string(), &rt).await?; let md = crate::scrapper::scrape::handle(
"https://icds-docs.readthedocs.io/en/latest/".to_string(),
&rt,
)
.await?;
// let md = "https://icds-docs.readthedocs.io/en/latest/".to_string();
// TODO: add wrapper structs to avoid accidentally // TODO: add wrapper structs to avoid accidentally
// query instead of MD and vis versa. // query instead of MD and vis versa.
let query = "".to_string(); let query = "This is the input, now you should refer this input and answer following questions"
let query = Question::process(&md, &query); .to_string();
// let query = "".to_string();
// let query = Question::process(&md, &query);
let query = Question::process(&md, &query); let query = Question::process(&md, &query);
let question = Question::new(&md, &query); let question = Question::new(&md, &query);
let wizard = Wizard::new("llama3.2".to_string(), None); let wizard = Wizard::new(
"gemini-1.5-flash-latest".to_string(),
rt.env.get_env("API_KEY"),
);
tracing::info!("Warming up!"); tracing::info!("Warming up!");
let _ans = wizard.ask(question).await?; let _ans = wizard.ask(question).await?;
tracing::info!("Warmup complete."); tracing::info!("Warmup complete.");
let app_ctx = AppCtx { let app_ctx = AppCtx { wizard, md };
wizard,
md,
};
http1::start_http_1(19194, app_ctx).await?; http1::start_http_1(19194, app_ctx).await?;
Ok(()) Ok(())

@ -1,6 +1,6 @@
use crate::target_rt::TargetRuntime;
use reqwest::Url; use reqwest::Url;
use scraper::{Html, Selector}; use scraper::{Html, Selector};
use crate::target_rt::TargetRuntime;
async fn handle_inner(url: &str, runtime: &TargetRuntime) -> anyhow::Result<String> { async fn handle_inner(url: &str, runtime: &TargetRuntime) -> anyhow::Result<String> {
let url = Url::parse(url)?; let url = Url::parse(url)?;
@ -20,10 +20,16 @@ pub async fn handle(mut url: String, runtime: &TargetRuntime) -> anyhow::Result<
loop { loop {
let parsed = Html::parse_document(&root); let parsed = Html::parse_document(&root);
let val = let val = Selector::parse("a").ok().and_then(|link_selector| {
Selector::parse("a") parsed.select(&link_selector).find(|val| {
.ok() val.text()
.and_then(|link_selector| parsed.select(&link_selector).find(|val| val.text().collect::<Vec<_>>().join(" ").trim().to_lowercase().eq("next"))); .collect::<Vec<_>>()
.join(" ")
.trim()
.to_lowercase()
.eq("next")
})
});
match val { match val {
Some(val) => { Some(val) => {
@ -38,9 +44,12 @@ pub async fn handle(mut url: String, runtime: &TargetRuntime) -> anyhow::Result<
break; break;
} }
} }
}; }
let mds = htmls.into_iter().map(|v| htmd::convert(v.as_str())).flatten().collect::<Vec<_>>(); let mds = htmls
.into_iter()
.map(|v| htmd::convert(v.as_str()))
.flatten()
.collect::<Vec<_>>();
let text = mds.join("\n"); let text = mds.join("\n");
// let text: String = md_to_text::convert(&text);
Ok(text) Ok(text)
} }

@ -1,15 +1,15 @@
use std::sync::Arc;
use bytes::Bytes; use bytes::Bytes;
use std::sync::Arc;
pub trait EnvIO {
fn get_env(&self, key: &str) -> Option<String>;
}
#[async_trait::async_trait] #[async_trait::async_trait]
pub trait HttpIO: Sync + Send + 'static { pub trait HttpIO: Sync + Send + 'static {
async fn execute( async fn execute(&self, request: reqwest::Request) -> anyhow::Result<reqwest::Response>;
&self,
request: reqwest::Request,
) -> anyhow::Result<reqwest::Response>;
} }
#[async_trait::async_trait] #[async_trait::async_trait]
pub trait FileIO: Send + Sync { pub trait FileIO: Send + Sync {
async fn write<'a>(&'a self, path: &'a str, content: &'a Bytes) -> anyhow::Result<()>; async fn write<'a>(&'a self, path: &'a str, content: &'a Bytes) -> anyhow::Result<()>;
@ -19,4 +19,5 @@ pub trait FileIO: Send + Sync {
pub struct TargetRuntime { pub struct TargetRuntime {
pub http: Arc<dyn HttpIO>, pub http: Arc<dyn HttpIO>,
pub fs: Arc<dyn FileIO>, pub fs: Arc<dyn FileIO>,
pub env: Arc<dyn EnvIO>,
} }