from pydantic import BaseModel, Field
from pydantic_ai import Agent

from dotenv import load_dotenv

# Populate os.environ from the local .env file before the agents below are built.
load_dotenv(dotenv_path='.env')


class Suggestions(BaseModel):
    """Suggestion for what the user might want to do next. Suggest things based on previous action and what you can do for them."""

    # NOTE: pydantic includes the class docstring above and the Field
    # description below in this model's JSON schema. Used as agent2's
    # output_type, that schema is presumably what pydantic_ai shows the LLM,
    # so treat both strings as prompt text rather than developer-only docs.
    # Required (``...`` = no default): the short suggestion text itself.
    suggestion: str = Field(..., description='Should be very short (a few words)')


# Free-text agent; its marker [A1] lets the repro tell whose instructions applied.
agent1 = Agent(
    'openai:gpt-4o-mini',
    instructions=(
        'You are a first helpful assistant. '
        'Whatever you reply, include the word [A1]'
    ),
)
# Structured-output agent; replies are parsed into `Suggestions` and should carry [A2].
agent2 = Agent(
    'openai:gpt-4o-mini',
    output_type=Suggestions,
    instructions=(
        'You are a second helpful assistant. '
        'Whatever you reply, include the word [A2]'
    ),
)

import logfire

# configure logfire
# 'if-token-present': traces are shipped to Logfire only when a token is
# available; presumably the script still runs locally without one — confirm.
# NOTE(review): this runs after agent1/agent2 are constructed; presumably the
# instrumentation still applies to those existing Agent instances — confirm.
logfire.configure(send_to_logfire='if-token-present')
# Instrument pydantic_ai so agent runs are traced through logfire.
logfire.instrument_pydantic_ai()


class LinearWorkflow:
    """Run a sequence of agents one after another over a shared message history.

    Each agent runs with the history accumulated so far: the caller's seed
    messages plus every message produced by the agents before it. Its new
    messages are appended before the next agent runs, and the last agent's
    run result is returned.
    """

    def __init__(self, agents) -> None:
        """Store the agents to run, in order.

        Args:
            agents: Iterable of agents. It is materialized into a list so that
                ``run`` can iterate it on every call; a one-shot iterator such
                as a generator would otherwise be exhausted after the first run.
        """
        self._agents = list(agents)

    async def run(
        self,
        *,
        message_history=None,
    ):
        """Run every agent in order and return the final agent's run result.

        Args:
            message_history: Messages to seed the conversation with. ``None``
                is treated as an empty history. The caller's sequence is
                copied and never mutated.

        Returns:
            The object returned by the last agent's ``run`` call.

        Raises:
            ValueError: If the workflow has no agents, so there is no result
                to return.
        """
        # Validate before opening the span. Without this check an empty agent
        # list would fail later with UnboundLocalError on ``result``.
        if not self._agents:
            raise ValueError('LinearWorkflow requires at least one agent')

        with logfire.span('LinearWorkflow.run'):
            # Copy so that extending the history below never mutates the
            # caller's list. ``or []`` also accepts None.
            current_history = list(message_history or [])

            for agent in self._agents:
                result = await agent.run(message_history=current_history)
                # Feed this agent's messages forward to the next agent.
                current_history.extend(result.new_messages())

            return result


async def bug_multi_agent() -> None:
    """Reproduce the suspected instruction leak between chained agents.

    Runs ``agent1`` then ``agent2`` through a ``LinearWorkflow`` seeded with a
    single user message. It prints the final output and reports which marker
    the structured ``Suggestions`` answer contains:

    - ``[A2]``: agent2's own instructions were applied.
    - ``[A1]``: agent1's instructions leaked into agent2's run.
    """
    # pydantic_ai's ``message_history`` expects ModelMessage objects, not
    # OpenAI-style {'role': ..., 'content': ...} dicts.
    from pydantic_ai.messages import ModelRequest, UserPromptPart

    workflow = LinearWorkflow([agent1, agent2])
    result = await workflow.run(
        message_history=[
            ModelRequest(parts=[UserPromptPart(content='Hello, how are you?')]),
        ],
    )
    print('Final output:', result.output)
    print('\nAgent2 response should contain [A2], not [A1]')

    # agent2 uses output_type=Suggestions. Any other output type means the
    # marker check below cannot be performed, so say so instead of going silent.
    if not isinstance(result.output, Suggestions):
        print(f'? UNCLEAR: unexpected output type {type(result.output).__name__}')
        return

    suggestion_text = result.output.suggestion
    print(f'\nSuggestion text: {suggestion_text}')
    if '[A2]' in suggestion_text:
        print("✓ CORRECT: Agent2's instructions were used")
    elif '[A1]' in suggestion_text:
        print("✗ BUG CONFIRMED: Agent1's instructions leaked into Agent2")
    else:
        print('? UNCLEAR: Neither marker found in response')


if __name__ == '__main__':
    import asyncio as _asyncio

    # Drive the async repro to completion on a fresh event loop.
    _asyncio.run(bug_multi_agent())
