Coverage for skema/rest/tests/test_llms.py: 100%

21 statements  

« prev     ^ index     » next       coverage.py v7.5.0, created at 2024-04-30 17:15 +0000

1from langchain.chat_models import ChatOpenAI 

2from langchain.prompts import ( 

3 ChatPromptTemplate, 

4 SystemMessagePromptTemplate, 

5 HumanMessagePromptTemplate, 

6) 

7from langchain.output_parsers import ( 

8 StructuredOutputParser, 

9 ResponseSchema 

10) 

11import langchain.schema 

12from skema.rest.proxies import SKEMA_OPENAI_KEY 

13 

14def test_prompt_construction(): 

15 """Tests prompt template instantiation""" 

16 # The assertions at the end of this test check the formatted message types and the type of the parsed output 

17 

18 code = "def sir(\n s: float, i: float, r: float, beta: float, gamma: float, n: float\n) -> Tuple[float, float, float]:\n \"\"\"The SIR model, one time step.\"\"\"\n s_n = (-beta * s * i) + s\n i_n = (beta * s * i - gamma * i) + i\n r_n = gamma * i + r\n scale = n / (s_n + i_n + r_n)\n return s_n * scale, i_n * scale, r_n * scale" 

19 

20 # schema of the fields expected in the model's structured response 

21 response_schemas = [ 

22 ResponseSchema(name="model_function", description="The name of the function that contains the model dynamics") 

23 ] 

24 

25 # for structured output parsing: builds a langchain output parser from the response schemas 

26 output_parser = StructuredOutputParser.from_response_schemas(response_schemas) 

27 

28 # for structured output parsing: gets the format instructions passed as a variable to the prompt template 

29 format_instructions = output_parser.get_format_instructions() 

30 

31 # construct the prompts 

32 template="You are a assistant that answers questions about code." 

33 system_message_prompt = SystemMessagePromptTemplate.from_template(template) 

34 human_template="Find the function that contains the model dynamics in {code} \n{format_instructions}" 

35 human_message_prompt = HumanMessagePromptTemplate.from_template(human_template) 

36 

37 # combine the system and human templates into a chat template 

38 chat_prompt = ChatPromptTemplate.from_messages([system_message_prompt, human_message_prompt]) 

39 

40 # format the prompt with the input variables and convert it to a list of messages 

41 formatted_prompt = chat_prompt.format_prompt(code=code, format_instructions = format_instructions).to_messages() 

42 

43 # mocks the output from the model (no model call is made in this test) 

44 output_mock = langchain.schema.messages.AIMessage(content='```json\n{\n\t"model_function": "sir"\n}\n```',additional_kwargs={}, example=False ) 

45 

46 parsed_output = output_parser.parse(output_mock.content) 

47 

48 assert isinstance(parsed_output['model_function'], str) 

49 assert isinstance(formatted_prompt[0], langchain.schema.messages.SystemMessage) 

50 assert isinstance(formatted_prompt[1], langchain.schema.messages.HumanMessage) 

51