A strategy that implements whatever it decorates (or is called on) using the LLM.
Source code in llm_strategy/llm_strategy.py
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
def llm_strategy(llm: BaseLLM) -> typing.Callable[[T], T]:  # noqa: C901
    """
    A strategy that implements whatever it decorates (or is called on) using the LLM.

    The returned decorator dispatches on what it receives:

    * a dataclass type -> ``llm_dataclass(f, llm)``;
    * a dataclass instance -> a new instance of ``llm_dataclass(type(f), llm)``
      constructed from the original instance's ``init`` field values;
    * anything else -> ``apply_decorator(f, inner_decorator)``, where the inner
      decorator wraps a function so each call is forwarded to an
      ``LLMFunction`` as ``llm_f(llm, *args, **kwargs)``.

    Args:
        llm: The language model used to implement the decorated object.

    Returns:
        A decorator that takes the object to implement and returns its
        LLM-backed replacement.

    Raises:
        AssertionError: raised by the decorator when
            ``can_wrap_member_in_llm(f)`` is false (only while assertions are
            enabled, i.e. not under ``python -O``).
    """

    @typing.no_type_check
    def decorator(f: T) -> T:
        # Kept as an ``assert`` so callers still see AssertionError; note that
        # this check is skipped entirely under ``python -O``.
        assert can_wrap_member_in_llm(f), f"llm_strategy cannot wrap {f!r}"

        if dataclasses.is_dataclass(f):
            if isinstance(f, type):
                # A dataclass type: let llm_dataclass implement it directly.
                return llm_dataclass(f, llm)
            # A dataclass instance: implement its type, then re-create the instance.
            return _instantiate_llm_dataclass(f, llm)

        def inner_decorator(unwrapped_f):
            return _llm_strategy_wrapper(unwrapped_f, llm)

        return apply_decorator(f, inner_decorator)

    return decorator


def _instantiate_llm_dataclass(instance, llm: BaseLLM):
    """Implement ``type(instance)`` with the LLM and rebuild *instance* as that type.

    Only fields declared with ``init=True`` are passed to the constructor,
    mirroring ``dataclasses.replace``: a dataclass-generated ``__init__`` does
    not accept ``init=False`` fields, so forwarding them raised ``TypeError``.
    Those fields are left to the new class's defaults / ``__post_init__``.
    (Assumes ``llm_dataclass`` keeps the original dataclass fields — confirm.)
    """
    implemented_dataclass = llm_dataclass(type(instance), llm)
    params = {
        field.name: getattr(instance, field.name)
        for field in dataclasses.fields(instance)
        if field.init
    }
    return implemented_dataclass(**params)


def _llm_strategy_wrapper(unwrapped_f, llm: BaseLLM):
    """Return a wrapper around *unwrapped_f* that calls an ``LLMFunction`` instead.

    The ``LLMFunction`` is built lazily on the first call (so the
    ``inspect.signature(..., eval_str=True)`` annotation resolution happens
    then, not at decoration time) and cached for later calls of this wrapper.
    """
    llm_f = None

    @functools.wraps(unwrapped_f)
    def strategy_wrapper(*args, **kwargs):
        nonlocal llm_f
        if llm_f is None:
            llm_f = _make_llm_function(unwrapped_f)
        return llm_f(llm, *args, **kwargs)

    return strategy_wrapper


def _make_llm_function(unwrapped_f):
    """Build an ``LLMFunction`` carrying *unwrapped_f*'s metadata plus a leading ``__llm`` parameter."""
    sig = inspect.signature(unwrapped_f, eval_str=True)
    # Prepend a positional-only ``__llm`` parameter so the LLM is passed first.
    new_params = [
        inspect.Parameter("__llm", inspect.Parameter.POSITIONAL_ONLY),
        *sig.parameters.values(),
    ]
    new_sig = sig.replace(parameters=new_params)

    def dummy_f(*_args, **_kwargs):
        # Placeholder that only carries metadata; it is never meant to run.
        raise NotImplementedError()

    new_f = functools.wraps(unwrapped_f)(dummy_f)
    new_f.__module__ = unwrapped_f.__module__
    # Expose the ``__llm``-prefixed signature on the placeholder.
    new_f.__signature__ = new_sig
    # functools.wraps set __wrapped__ = unwrapped_f; drop it — presumably so
    # anything following the __wrapped__ chain stops at new_f (and its
    # modified signature) rather than reaching the original function.
    del new_f.__wrapped__

    # Wrap the placeholder's metadata onto an LLMFunction.
    return functools.wraps(new_f)(LLMFunction())
|