Skip to content

chat_chain

ChatChain dataclass

Source code in llm_strategy/chat_chain.py
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
@dataclass
class ChatChain:
    """A chat model paired with the message history of one conversation.

    The methods that extend the conversation (`append`, `+`, `query`,
    `structured_query`, `branch`) do not modify this instance; they build a
    new `ChatChain` via `dataclasses.replace`.
    """

    # Model that is invoked with the message history to produce replies.
    chat_model: BaseChatModel
    # Conversation so far, oldest message first.
    messages: list[BaseMessage]

    @property
    def response(self) -> str:
        """Return the content of the last message in the chain.

        Raises:
            AssertionError: if the chain holds no messages (note that
                assertions are skipped when Python runs with ``-O``).
        """
        assert len(self.messages) >= 1
        return cast(str, self.messages[-1].content)

    def append(self, messages: list[BaseMessage]) -> "ChatChain":
        """Return a new chain whose history is this one's followed by `messages`."""
        return dataclasses.replace(self, messages=self.messages + messages)

    def __add__(self, other: list[BaseMessage]) -> "ChatChain":
        """Support ``chain + [message, ...]`` as shorthand for `append`."""
        return self.append(other)

    def query(self, question: str, model_args: dict | None = None) -> Tuple[str, "ChatChain"]:
        """Ask a question and return the reply in a single block.

        Args:
            question: Text sent to the model as a `HumanMessage`.
            model_args: Optional keyword arguments forwarded to
                `chat_model.invoke`; `None` or an empty dict forwards nothing.

        Returns:
            A tuple of the reply's content and a new chain whose history is
            the current messages, the question, and the model's reply.
        """
        # Concatenation creates a fresh list, so self.messages is left untouched.
        messages = self.messages + [HumanMessage(content=question)]
        model_args = model_args or {}
        reply = self.chat_model.invoke(messages, **model_args)
        messages.append(reply)
        return cast(str, reply.content), dataclasses.replace(self, messages=messages)

    def enforce_json_response(self, model_args: dict | None = None) -> dict:
        """Return `model_args`, extended with a JSON response format where applicable.

        When the model's `dict()` reports a `_type` containing "openai" and a
        `model_name` of "gpt-4-1106-preview" or "gpt-3.5-turbo-1106", a
        `response_format` entry of ``{"type": "json_object"}`` is added (and
        overrides any `response_format` already present). Otherwise the
        arguments are returned as given, with `None` replaced by an empty dict.

        The input dict is never mutated; the extended case builds a new dict.
        """
        model_args = model_args or {}
        # Check if the language model is of type "openai" and extend model args with a response format in that case
        model_dict = self.chat_model.dict()
        if "openai" in model_dict["_type"] and model_dict.get("model_name") in (
            "gpt-4-1106-preview",
            "gpt-3.5-turbo-1106",
        ):
            model_args = {**model_args, "response_format": dict(type="json_object")}
        return model_args

    def structured_query(
        self, question: str, return_type: type[B], model_args: dict | None = None
    ) -> Tuple[B, "ChatChain"]:
        """Ask a question and parse the reply into a structured output.

        A pydantic model named "StructuredOutput" with a single `result` field
        is created from `return_type`, the parser's format instructions are
        appended to the question, and the reply is parsed with that parser.

        Args:
            question: The question text; format instructions are appended
                after a blank line.
            return_type: Type of the `result` field. An `Annotated[...]` type
                is expanded into its arguments for the field definition;
                any other type becomes a required field.
            model_args: Optional keyword arguments for the model call; they
                are passed through `enforce_json_response` first.

        Returns:
            A tuple of the parsed reply and the new chain returned by `query`.
            NOTE(review): the parsed reply is whatever `parser.parse` returns,
            cast to `B` without a runtime check - presumably the
            "StructuredOutput" instance carrying the value in `.result`;
            confirm against callers.
        """
        if typing.get_origin(return_type) is typing.Annotated:
            return_info = typing.get_args(return_type)
        else:
            return_info = (return_type, ...)

        output_model = create_model("StructuredOutput", result=return_info)
        parser: PydanticOutputParser = PydanticOutputParser(pydantic_object=output_model)
        question_and_formatting = question + "\n\n" + parser.get_format_instructions()

        # `query` accepts the model arguments as a single dict parameter.
        # Unpacking them with ** would hand e.g. `response_format` to `query`
        # itself and raise TypeError for any non-empty dict.
        reply_content, chain = self.query(
            question_and_formatting, model_args=self.enforce_json_response(model_args)
        )
        parsed_reply: B = typing.cast(B, parser.parse(reply_content))

        return parsed_reply, chain

    def branch(self) -> "ChatChain":
        """Return a new chain with a shallow copy of the message list."""
        return dataclasses.replace(self, messages=self.messages.copy())

query(question, model_args=None)

Asks a question and returns the result in a single block.

Source code in llm_strategy/chat_chain.py
31
32
33
34
35
36
37
38
def query(self, question: str, model_args: dict | None = None) -> Tuple[str, "ChatChain"]:
    """Send `question` to the chat model and return its reply with the extended chain.

    Args:
        question: Text wrapped in a `HumanMessage` and appended to the history.
        model_args: Optional keyword arguments forwarded to `chat_model.invoke`;
            a falsy value forwards nothing.

    Returns:
        A tuple of the reply's content and a new chain whose messages are the
        current history, the question, and the reply.
    """
    # `+` builds a new list, so the history held by `self` is not modified.
    conversation = self.messages + [HumanMessage(content=question)]
    invoke_kwargs = model_args if model_args else {}

    answer = self.chat_model.invoke(conversation, **invoke_kwargs)
    conversation.append(answer)

    answer_text = cast(str, answer.content)
    extended_chain = dataclasses.replace(self, messages=conversation)
    return answer_text, extended_chain

structured_query(question, return_type, model_args=None)

Asks a question and returns the reply parsed into a structured output built from `return_type`, together with the extended chain.

Source code in llm_strategy/chat_chain.py
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
def structured_query(
    self, question: str, return_type: type[B], model_args: dict | None = None
) -> Tuple[B, "ChatChain"]:
    """Ask a question and parse the reply into a structured output.

    A pydantic model named "StructuredOutput" with a single `result` field is
    created from `return_type`, the parser's format instructions are appended
    to the question, and the reply is parsed with that parser.

    Args:
        question: The question text; format instructions are appended after
            a blank line.
        return_type: Type of the `result` field. An `Annotated[...]` type is
            expanded into its arguments for the field definition; any other
            type becomes a required field.
        model_args: Optional keyword arguments for the model call; they are
            passed through `self.enforce_json_response` first.

    Returns:
        A tuple of the parsed reply and the new chain returned by `query`.
        NOTE(review): the parsed reply is whatever `parser.parse` returns,
        cast to `B` without a runtime check - presumably the
        "StructuredOutput" instance carrying the value in `.result`; confirm
        against callers.
    """
    if typing.get_origin(return_type) is typing.Annotated:
        return_info = typing.get_args(return_type)
    else:
        return_info = (return_type, ...)

    output_model = create_model("StructuredOutput", result=return_info)
    parser: PydanticOutputParser = PydanticOutputParser(pydantic_object=output_model)
    question_and_formatting = question + "\n\n" + parser.get_format_instructions()

    # `query` accepts the model arguments as a single dict parameter.
    # Unpacking them with ** would hand e.g. `response_format` to `query`
    # itself and raise TypeError for any non-empty dict.
    reply_content, chain = self.query(
        question_and_formatting, model_args=self.enforce_json_response(model_args)
    )
    parsed_reply: B = typing.cast(B, parser.parse(reply_content))

    return parsed_reply, chain