diff --git a/server/agent/tools_factory/calculate.py b/server/agent/tools_factory/calculate.py index e66292f6..880aa5c8 100644 --- a/server/agent/tools_factory/calculate.py +++ b/server/agent/tools_factory/calculate.py @@ -1,23 +1,15 @@ -from server.pydantic_types import BaseModel, Field +from langchain.agents import tool -def calculate(a: float, b: float, operator: str) -> float: - if operator == "+": - return a + b - elif operator == "-": - return a - b - elif operator == "*": - return a * b - elif operator == "/": - if b != 0: - return a / b - else: - return float('inf') - elif operator == "^": - return a ** b - else: - raise ValueError("Unsupported operator") -class CalculatorInput(BaseModel): - a: float = Field(description="first number") - b: float = Field(description="second number") - operator: str = Field(description="operator to use (e.g., +, -, *, /, ^)") +@tool +def calculate(text: str) -> float: + ''' + Useful to answer questions about simple calculations. + translate user question to a math expression that can be evaluated by numexpr. + ''' + import numexpr + + try: + return str(numexpr.evaluate(text)) + except Exception as e: + return f"wrong: {e}" diff --git a/server/agent/tools_factory/tools_registry.py b/server/agent/tools_factory/tools_registry.py index 9fb6ee1f..7934adc7 100644 --- a/server/agent/tools_factory/tools_registry.py +++ b/server/agent/tools_factory/tools_registry.py @@ -7,12 +7,7 @@ KB_info_str = '\n'.join([f"{key}: {value}" for key, value in KB_INFO.items()]) template_knowledge = template.format(KB_info=KB_info_str, key="samples") all_tools = [ - StructuredTool.from_function( - func=calculate, - name="calculate", - description="Useful for when you need to answer questions about simple calculations", - args_schema=CalculatorInput, - ), + calculate, StructuredTool.from_function( func=arxiv, name="arxiv",