Created using Colaboratory

sugarforever 2023-11-20 23:48:55 +00:00
parent 713c9bb462
commit d0c1f442cc

@@ -5,7 +5,7 @@
 "colab": {
   "provenance": [],
   "gpuType": "T4",
-  "authorship_tag": "ABX9TyOsgwsp69IaZ0UMnvAVkdVX",
+  "authorship_tag": "ABX9TyPnIDleZ4upjO9LLlSfEb5e",
   "include_colab_link": true
 },
 "kernelspec": {
@@ -1975,17 +1975,13 @@
 {
   "cell_type": "code",
   "source": [
-    "# Prompt\n",
     "prompt_text = \"\"\"\n",
     " You are responsible for concisely summarizing table or text chunk:\n",
     "\n",
     " {element}\n",
     "\"\"\"\n",
     "prompt = ChatPromptTemplate.from_template(prompt_text)\n",
-    "\n",
-    "# Summarization chain\n",
-    "model = ChatOpenAI(temperature=0, model=\"gpt-3.5-turbo\")\n",
-    "summarize_chain = {\"element\": lambda x: x} | prompt | model | StrOutputParser()"
+    "summarize_chain = {\"element\": lambda x: x} | prompt | ChatOpenAI(temperature=0, model=\"gpt-3.5-turbo\") | StrOutputParser()"
   ],
   "metadata": {
     "id": "uDQYbnKDbM7C"
@@ -2005,10 +2001,9 @@
 {
   "cell_type": "code",
   "source": [
-    "# Apply to tables\n",
     "tables = [i.text for i in table_elements]\n",
     "table_summaries = summarize_chain.batch(tables, {\"max_concurrency\": 5})\n",
-    "# Apply to texts\n",
+    "\n",
     "texts = [i.text for i in text_elements]\n",
     "text_summaries = summarize_chain.batch(texts, {\"max_concurrency\": 5})"
   ],
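Worth noting about the `batch` calls in this cell: the second argument is a runnable config, and `max_concurrency: 5` caps how many summarization requests run in parallel. A short sketch of the property the later indexing cells rely on:

```python
# batch() maps the chain over the list and preserves input order,
# so each summary can be paired back with its source element.
tables = [el.text for el in table_elements]
table_summaries = summarize_chain.batch(tables, {"max_concurrency": 5})
assert len(table_summaries) == len(tables)
```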
@@ -2038,17 +2033,12 @@
     "from langchain.storage import InMemoryStore\n",
     "from langchain.vectorstores import Chroma\n",
     "\n",
-    "# The vectorstore to use to index the child chunks\n",
-    "vectorstore = Chroma(collection_name=\"summaries\", embedding_function=OpenAIEmbeddings())\n",
-    "\n",
-    "# The storage layer for the parent documents\n",
-    "store = InMemoryStore()\n",
     "id_key = \"doc_id\"\n",
     "\n",
     "# The retriever (empty to start)\n",
     "retriever = MultiVectorRetriever(\n",
-    "    vectorstore=vectorstore,\n",
-    "    docstore=store,\n",
+    "    vectorstore=Chroma(collection_name=\"summaries\", embedding_function=OpenAIEmbeddings()),\n",
+    "    docstore=InMemoryStore(),\n",
     "    id_key=id_key,\n",
     ")\n",
     "\n",
@@ -2081,21 +2071,17 @@
   "source": [
     "from langchain.schema.runnable import RunnablePassthrough\n",
     "\n",
-    "# Prompt template\n",
     "template = \"\"\"Answer the question based only on the following context, which can include text and tables:\n",
     "{context}\n",
     "Question: {question}\n",
     "\"\"\"\n",
     "prompt = ChatPromptTemplate.from_template(template)\n",
     "\n",
-    "# LLM\n",
-    "model = ChatOpenAI(temperature=0, model=\"gpt-4\")\n",
-    "\n",
     "# RAG pipeline\n",
     "chain = (\n",
     "    {\"context\": retriever, \"question\": RunnablePassthrough()}\n",
     "    | prompt\n",
-    "    | model\n",
+    "    | ChatOpenAI(temperature=0, model=\"gpt-4\")\n",
     "    | StrOutputParser()\n",
     ")"
   ],
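At query time, the leading dict in this chain fans the question out to two branches: `retriever` searches the summary embeddings and returns the matching parent chunks as `{context}`, while `RunnablePassthrough()` forwards the raw question string into `{question}`. An illustrative invocation (the question is a made-up example, not from the notebook):

```python
# Hypothetical query; any natural-language question works here.
# GPT-4 answers strictly from the retrieved text and table chunks.
answer = chain.invoke("What do the tables report about accuracy?")
print(answer)
```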