Coverage for skema/rest/tests/test_llms.py: 100%
21 statements
« prev ^ index » next coverage.py v7.5.0, created at 2024-04-30 17:15 +0000
1from langchain.chat_models import ChatOpenAI
2from langchain.prompts import (
3 ChatPromptTemplate,
4 SystemMessagePromptTemplate,
5 HumanMessagePromptTemplate,
6)
7from langchain.output_parsers import (
8 StructuredOutputParser,
9 ResponseSchema
10)
11import langchain.schema
12from skema.rest.proxies import SKEMA_OPENAI_KEY
def test_prompt_construction():
    """Tests prompt template instantiation"""
    # Sample source code the prompt asks the model about (SIR dynamics, one step).
    code = "def sir(\n s: float, i: float, r: float, beta: float, gamma: float, n: float\n) -> Tuple[float, float, float]:\n \"\"\"The SIR model, one time step.\"\"\"\n s_n = (-beta * s * i) + s\n i_n = (beta * s * i - gamma * i) + i\n r_n = gamma * i + r\n scale = n / (s_n + i_n + r_n)\n return s_n * scale, i_n * scale, r_n * scale"

    # Declare the single structured field we expect the LLM to return.
    schema = ResponseSchema(
        name="model_function",
        description="The name of the function that contains the model dynamics",
    )

    # Build a parser from the schema; its format instructions are injected
    # into the prompt so the model knows how to shape its reply.
    parser = StructuredOutputParser.from_response_schemas([schema])
    instructions = parser.get_format_instructions()

    # System + human message templates, combined into one chat template.
    system_prompt = SystemMessagePromptTemplate.from_template(
        "You are a assistant that answers questions about code."
    )
    human_prompt = HumanMessagePromptTemplate.from_template(
        "Find the function that contains the model dynamics in {code} \n{format_instructions}"
    )
    chat_template = ChatPromptTemplate.from_messages([system_prompt, human_prompt])

    # Fill the template variables and render the prompt to a message list.
    messages = chat_template.format_prompt(
        code=code, format_instructions=instructions
    ).to_messages()

    # Mocked model output (no API call); parse the structured field out of it.
    fake_reply = langchain.schema.messages.AIMessage(
        content='```json\n{\n\t"model_function": "sir"\n}\n```',
        additional_kwargs={},
        example=False,
    )
    parsed = parser.parse(fake_reply.content)

    assert isinstance(parsed["model_function"], str)
    assert isinstance(messages[0], langchain.schema.messages.SystemMessage)
    assert isinstance(messages[1], langchain.schema.messages.HumanMessage)