Can't compile the model?

#26
by Luisds95 - opened

Hi! Thanks a lot for this model, seems promising!

I tried to compile it with `torch.compile` so that it runs and trains faster, but I'm getting a device mismatch (tensors on `cuda:0` and `cpu`) at one of the attention layers. Any clues on how to solve this?

Here's a minimal example:

import torch
from sentence_transformers import SentenceTransformer
model = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")
model = torch.compile(model)

tokens = model.tokenize(["Hello", "World"])
model(tokens)

This raises the following error:
Unhandled FakeTensor Device Propagation for aten._scaled_dot_product_efficient_attention.default, found two different devices cuda:0, cpu

I'm using sentence-transformers==3.4.1 and torch==2.3.1

Here's the full traceback in case it helps:

TorchRuntimeError                         Traceback (most recent call last)
Cell In[2], line 7
      4 model = torch.compile(model)
      6 tokens = model.tokenize(["Hello", "World"])
----> 7 model(tokens)

File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/nn/modules/module.py:1739, in Module._wrapped_call_impl(self, *args, **kwargs)
   1737     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1738 else:
-> 1739     return self._call_impl(*args, **kwargs)

File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/nn/modules/module.py:1750, in Module._call_impl(self, *args, **kwargs)
   1745 # If we don't have any hooks, we want to skip the rest of the logic in
   1746 # this function, and just call forward.
   1747 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1748         or _global_backward_pre_hooks or _global_backward_hooks
   1749         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1750     return forward_call(*args, **kwargs)
   1752 result = None
   1753 called_always_called_hooks = set()

File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py:574, in _TorchDynamoContext.__call__.<locals>._fn(*args, **kwargs)
    569 saved_dynamic_layer_stack_depth = (
    570     torch._C._functorch.get_dynamic_layer_stack_depth()
    571 )
    573 try:
--> 574     return fn(*args, **kwargs)
    575 finally:
    576     # Restore the dynamic layer stack depth if necessary.
    577     torch._C._functorch.pop_dynamic_layer_stack_and_undo_to_depth(
    578         saved_dynamic_layer_stack_depth
    579     )

File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/nn/modules/module.py:1739, in Module._wrapped_call_impl(self, *args, **kwargs)
   1737     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1738 else:
-> 1739     return self._call_impl(*args, **kwargs)

File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/nn/modules/module.py:1750, in Module._call_impl(self, *args, **kwargs)
   1745 # If we don't have any hooks, we want to skip the rest of the logic in
   1746 # this function, and just call forward.
   1747 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1748         or _global_backward_pre_hooks or _global_backward_hooks
   1749         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1750     return forward_call(*args, **kwargs)
   1752 result = None
   1753 called_always_called_hooks = set()

File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py:1380, in CatchErrorsWrapper.__call__(self, frame, cache_entry, frame_state)
   1374             return hijacked_callback(
   1375                 frame, cache_entry, self.hooks, frame_state
   1376             )
   1378 with compile_lock, _disable_current_modes():
   1379     # skip=1: skip this frame
-> 1380     return self._torchdynamo_orig_callable(
   1381         frame, cache_entry, self.hooks, frame_state, skip=1
   1382     )

File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py:1164, in ConvertFrame.__call__(self, frame, cache_entry, hooks, frame_state, skip)
   1162 counters["frames"]["total"] += 1
   1163 try:
-> 1164     result = self._inner_convert(
   1165         frame, cache_entry, hooks, frame_state, skip=skip + 1
   1166     )
   1167     counters["frames"]["ok"] += 1
   1168     return result

File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py:547, in ConvertFrameAssert.__call__(self, frame, cache_entry, hooks, frame_state, skip)
    544     dynamo_tls.traced_frame_infos.append(info)
    546 with compile_context(CompileContext(compile_id)):
--> 547     return _compile(
    548         frame.f_code,
    549         frame.f_globals,
    550         frame.f_locals,
    551         frame.f_builtins,
    552         frame.closure,
    553         self._torchdynamo_orig_callable,
    554         self._one_graph,
    555         self._export,
    556         self._export_constraints,
    557         hooks,
    558         cache_entry,
    559         cache_size,
    560         frame,
    561         frame_state=frame_state,
    562         compile_id=compile_id,
    563         skip=skip + 1,
    564     )

File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py:986, in _compile(code, globals, locals, builtins, closure, compiler_fn, one_graph, export, export_constraints, hooks, cache_entry, cache_size, frame, frame_state, compile_id, skip)
    984 guarded_code = None
    985 try:
--> 986     guarded_code = compile_inner(code, one_graph, hooks, transform)
    988     # NB: We only put_code_state in success case.  Success case here
    989     # does include graph breaks; specifically, if a graph break still
    990     # resulted in a partially compiled graph, we WILL return here.  An
   (...)    995     # to upload for graph break though, because this can prevent
    996     # extra graph break compilations.)
    997     put_code_state()

File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py:715, in _compile.<locals>.compile_inner(code, one_graph, hooks, transform)
    713     stack.enter_context(torch._dynamo.callback_handler.install_callbacks())
    714     stack.enter_context(CompileTimeInstructionCounter.record())
--> 715     return _compile_inner(code, one_graph, hooks, transform)
    717 return None

File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_utils_internal.py:95, in compile_time_strobelight_meta.<locals>.compile_time_strobelight_meta_inner.<locals>.wrapper_function(*args, **kwargs)
     92     kwargs["skip"] = skip + 1
     94 if not StrobelightCompileTimeProfiler.enabled:
---> 95     return function(*args, **kwargs)
     97 return StrobelightCompileTimeProfiler.profile_compile_time(
     98     function, phase_name, *args, **kwargs
     99 )

File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py:750, in _compile.<locals>._compile_inner(code, one_graph, hooks, transform)
    748 CompileContext.get().attempt = attempt
    749 try:
--> 750     out_code = transform_code_object(code, transform)
    751     break
    752 except exc.RestartAnalysis as e:

File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/bytecode_transformation.py:1361, in transform_code_object(code, transformations, safe)
   1358 instructions = cleaned_instructions(code, safe)
   1359 propagate_line_nums(instructions)
-> 1361 transformations(instructions, code_options)
   1362 return clean_and_assemble_instructions(instructions, keys, code_options)[1]

File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py:231, in preserve_global_state.<locals>._fn(*args, **kwargs)
    229 exit_stack.enter_context(torch_function_mode_stack_state_mgr)
    230 try:
--> 231     return fn(*args, **kwargs)
    232 finally:
    233     cleanup.close()

File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py:662, in _compile.<locals>.transform(instructions, code_options)
    660 try:
    661     with tracing(tracer.output.tracing_context), tracer.set_current_tx():
--> 662         tracer.run()
    663 except exc.UnspecializeRestartAnalysis:
    664     speculation_log.clear()

File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:2868, in InstructionTranslator.run(self)
   2867 def run(self):
-> 2868     super().run()

File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:1052, in InstructionTranslatorBase.run(self)
   1050 try:
   1051     self.output.push_tx(self)
-> 1052     while self.step():
   1053         pass
   1054 except TensorifyScalarRestartAnalysis:

File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:962, in InstructionTranslatorBase.step(self)
    959 self.update_block_stack(inst)
    961 try:
--> 962     self.dispatch_table[inst.opcode](self, inst)
    963     return not self.output.should_exit
    964 except TensorifyScalarRestartAnalysis:

File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:659, in break_graph_if_unsupported.<locals>.decorator.<locals>.wrapper(self, inst)
    657     return handle_graph_break(self, inst, speculation.reason)
    658 try:
--> 659     return inner_fn(self, inst)
    660 except Unsupported as excp:
    661     if self.generic_context_manager_depth > 0:
    662         # We don't support graph break under GenericContextWrappingVariable,
    663         # If there is, we roll back to the checkpoint and fall back.

File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:1736, in InstructionTranslatorBase.CALL_FUNCTION_EX(self, inst)
   1734 # Map to a dictionary of str -> VariableTracker
   1735 kwargsvars = kwargsvars.keys_as_python_constant()
-> 1736 self.call_function(fn, argsvars.items, kwargsvars)

File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:897, in InstructionTranslatorBase.call_function(self, fn, args, kwargs)
    895 if inner_fn and callable(inner_fn) and is_forbidden(inner_fn):
    896     raise AssertionError(f"Attempt to trace forbidden callable {inner_fn}")
--> 897 self.push(fn.call_function(self, args, kwargs))

File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/variables/nn_module.py:914, in UnspecializedNNModuleVariable.call_function(self, tx, args, kwargs)
    906 ctx = (
    907     record_nn_module_stack(
    908         str(id(mod)), self.get_nn_module_stack_source(), tx, mod
   (...)    911     else nullcontext()
    912 )
    913 with ctx:
--> 914     return variables.UserFunctionVariable(fn, source=source).call_function(
    915         tx, [self] + list(args), kwargs
    916     )

File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/variables/functions.py:317, in UserFunctionVariable.call_function(self, tx, args, kwargs)
    315         with torch._dynamo.side_effects.allow_side_effects_under_checkpoint(tx):
    316             return super().call_function(tx, args, kwargs)
--> 317 return super().call_function(tx, args, kwargs)

File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/variables/functions.py:118, in BaseUserFunctionVariable.call_function(self, tx, args, kwargs)
    112 def call_function(
    113     self,
    114     tx: "InstructionTranslator",
    115     args: "List[VariableTracker]",
    116     kwargs: "Dict[str, VariableTracker]",
    117 ) -> "VariableTracker":
--> 118     return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)

File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:903, in InstructionTranslatorBase.inline_user_function_return(self, fn, args, kwargs)
    899 def inline_user_function_return(self, fn, args, kwargs):
    900     """
    901     A call to some user defined function by inlining it.
    902     """
--> 903     return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)

File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:3072, in InliningInstructionTranslator.inline_call(cls, parent, func, args, kwargs)
   3069 @classmethod
   3070 def inline_call(cls, parent, func, args, kwargs):
   3071     with patch.dict(counters, {"unimplemented": counters["inline_call"]}):
-> 3072         return cls.inline_call_(parent, func, args, kwargs)

File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:3198, in InliningInstructionTranslator.inline_call_(parent, func, args, kwargs)
   3196 try:
   3197     with strict_ctx:
-> 3198         tracer.run()
   3199 except exc.ObservedException as e:
   3200     msg = f"Observed exception DURING INLING {code} : {e}"

File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:1052, in InstructionTranslatorBase.run(self)
   1050 try:
   1051     self.output.push_tx(self)
-> 1052     while self.step():
   1053         pass
   1054 except TensorifyScalarRestartAnalysis:

File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:962, in InstructionTranslatorBase.step(self)
    959 self.update_block_stack(inst)
    961 try:
--> 962     self.dispatch_table[inst.opcode](self, inst)
    963     return not self.output.should_exit
    964 except TensorifyScalarRestartAnalysis:

File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:659, in break_graph_if_unsupported.<locals>.decorator.<locals>.wrapper(self, inst)
    657     return handle_graph_break(self, inst, speculation.reason)
    658 try:
--> 659     return inner_fn(self, inst)
    660 except Unsupported as excp:
    661     if self.generic_context_manager_depth > 0:
    662         # We don't support graph break under GenericContextWrappingVariable,
    663         # If there is, we roll back to the checkpoint and fall back.

File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:1736, in InstructionTranslatorBase.CALL_FUNCTION_EX(self, inst)
   1734 # Map to a dictionary of str -> VariableTracker
   1735 kwargsvars = kwargsvars.keys_as_python_constant()
-> 1736 self.call_function(fn, argsvars.items, kwargsvars)

File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:897, in InstructionTranslatorBase.call_function(self, fn, args, kwargs)
    895 if inner_fn and callable(inner_fn) and is_forbidden(inner_fn):
    896     raise AssertionError(f"Attempt to trace forbidden callable {inner_fn}")
--> 897 self.push(fn.call_function(self, args, kwargs))

File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/variables/lazy.py:170, in _create_realize_and_forward.<locals>.realize_and_forward(self, *args, **kwargs)
    166 @functools.wraps(getattr(VariableTracker, name))
    167 def realize_and_forward(
    168     self: LazyVariableTracker, *args: Any, **kwargs: Any
    169 ) -> Any:
--> 170     return getattr(self.realize(), name)(*args, **kwargs)

File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/variables/nn_module.py:914, in UnspecializedNNModuleVariable.call_function(self, tx, args, kwargs)
    906 ctx = (
    907     record_nn_module_stack(
    908         str(id(mod)), self.get_nn_module_stack_source(), tx, mod
   (...)    911     else nullcontext()
    912 )
    913 with ctx:
--> 914     return variables.UserFunctionVariable(fn, source=source).call_function(
    915         tx, [self] + list(args), kwargs
    916     )

File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/variables/functions.py:317, in UserFunctionVariable.call_function(self, tx, args, kwargs)
    315         with torch._dynamo.side_effects.allow_side_effects_under_checkpoint(tx):
    316             return super().call_function(tx, args, kwargs)
--> 317 return super().call_function(tx, args, kwargs)

File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/variables/functions.py:118, in BaseUserFunctionVariable.call_function(self, tx, args, kwargs)
    112 def call_function(
    113     self,
    114     tx: "InstructionTranslator",
    115     args: "List[VariableTracker]",
    116     kwargs: "Dict[str, VariableTracker]",
    117 ) -> "VariableTracker":
--> 118     return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)

File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:903, in InstructionTranslatorBase.inline_user_function_return(self, fn, args, kwargs)
    899 def inline_user_function_return(self, fn, args, kwargs):
    900     """
    901     A call to some user defined function by inlining it.
    902     """
--> 903     return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)

File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:3072, in InliningInstructionTranslator.inline_call(cls, parent, func, args, kwargs)
   3069 @classmethod
   3070 def inline_call(cls, parent, func, args, kwargs):
   3071     with patch.dict(counters, {"unimplemented": counters["inline_call"]}):
-> 3072         return cls.inline_call_(parent, func, args, kwargs)

File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:3198, in InliningInstructionTranslator.inline_call_(parent, func, args, kwargs)
   3196 try:
   3197     with strict_ctx:
-> 3198         tracer.run()
   3199 except exc.ObservedException as e:
   3200     msg = f"Observed exception DURING INLING {code} : {e}"

File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:1052, in InstructionTranslatorBase.run(self)
   1050 try:
   1051     self.output.push_tx(self)
-> 1052     while self.step():
   1053         pass
   1054 except TensorifyScalarRestartAnalysis:

File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:962, in InstructionTranslatorBase.step(self)
    959 self.update_block_stack(inst)
    961 try:
--> 962     self.dispatch_table[inst.opcode](self, inst)
    963     return not self.output.should_exit
    964 except TensorifyScalarRestartAnalysis:

File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:659, in break_graph_if_unsupported.<locals>.decorator.<locals>.wrapper(self, inst)
    657     return handle_graph_break(self, inst, speculation.reason)
    658 try:
--> 659     return inner_fn(self, inst)
    660 except Unsupported as excp:
    661     if self.generic_context_manager_depth > 0:
    662         # We don't support graph break under GenericContextWrappingVariable,
    663         # If there is, we roll back to the checkpoint and fall back.

File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:2341, in InstructionTranslatorBase.CALL(self, inst)
   2339 @break_graph_if_unsupported(push=1)
   2340 def CALL(self, inst):
-> 2341     self._call(inst)

File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:2335, in InstructionTranslatorBase._call(self, inst, call_kw)
   2330     kwargs = {}
   2332 try:
   2333     # if call_function fails, need to set kw_names to None, otherwise
   2334     # a subsequent call may have self.kw_names set to an old value
-> 2335     self.call_function(fn, args, kwargs)
   2336 finally:
   2337     self.kw_names = None

File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:897, in InstructionTranslatorBase.call_function(self, fn, args, kwargs)
    895 if inner_fn and callable(inner_fn) and is_forbidden(inner_fn):
    896     raise AssertionError(f"Attempt to trace forbidden callable {inner_fn}")
--> 897 self.push(fn.call_function(self, args, kwargs))

File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/variables/lazy.py:170, in _create_realize_and_forward.<locals>.realize_and_forward(self, *args, **kwargs)
    166 @functools.wraps(getattr(VariableTracker, name))
    167 def realize_and_forward(
    168     self: LazyVariableTracker, *args: Any, **kwargs: Any
    169 ) -> Any:
--> 170     return getattr(self.realize(), name)(*args, **kwargs)

File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/variables/nn_module.py:914, in UnspecializedNNModuleVariable.call_function(self, tx, args, kwargs)
    906 ctx = (
    907     record_nn_module_stack(
    908         str(id(mod)), self.get_nn_module_stack_source(), tx, mod
   (...)    911     else nullcontext()
    912 )
    913 with ctx:
--> 914     return variables.UserFunctionVariable(fn, source=source).call_function(
    915         tx, [self] + list(args), kwargs
    916     )

File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/variables/functions.py:317, in UserFunctionVariable.call_function(self, tx, args, kwargs)
    315         with torch._dynamo.side_effects.allow_side_effects_under_checkpoint(tx):
    316             return super().call_function(tx, args, kwargs)
--> 317 return super().call_function(tx, args, kwargs)

File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/variables/functions.py:118, in BaseUserFunctionVariable.call_function(self, tx, args, kwargs)
    112 def call_function(
    113     self,
    114     tx: "InstructionTranslator",
    115     args: "List[VariableTracker]",
    116     kwargs: "Dict[str, VariableTracker]",
    117 ) -> "VariableTracker":
--> 118     return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)

File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:903, in InstructionTranslatorBase.inline_user_function_return(self, fn, args, kwargs)
    899 def inline_user_function_return(self, fn, args, kwargs):
    900     """
    901     A call to some user defined function by inlining it.
    902     """
--> 903     return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)

File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:3072, in InliningInstructionTranslator.inline_call(cls, parent, func, args, kwargs)
   3069 @classmethod
   3070 def inline_call(cls, parent, func, args, kwargs):
   3071     with patch.dict(counters, {"unimplemented": counters["inline_call"]}):
-> 3072         return cls.inline_call_(parent, func, args, kwargs)

File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:3198, in InliningInstructionTranslator.inline_call_(parent, func, args, kwargs)
   3196 try:
   3197     with strict_ctx:
-> 3198         tracer.run()
   3199 except exc.ObservedException as e:
   3200     msg = f"Observed exception DURING INLING {code} : {e}"

File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:1052, in InstructionTranslatorBase.run(self)
   1050 try:
   1051     self.output.push_tx(self)
-> 1052     while self.step():
   1053         pass
   1054 except TensorifyScalarRestartAnalysis:

File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:962, in InstructionTranslatorBase.step(self)
    959 self.update_block_stack(inst)
    961 try:
--> 962     self.dispatch_table[inst.opcode](self, inst)
    963     return not self.output.should_exit
    964 except TensorifyScalarRestartAnalysis:

File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:659, in break_graph_if_unsupported.<locals>.decorator.<locals>.wrapper(self, inst)
    657     return handle_graph_break(self, inst, speculation.reason)
    658 try:
--> 659     return inner_fn(self, inst)
    660 except Unsupported as excp:
    661     if self.generic_context_manager_depth > 0:
    662         # We don't support graph break under GenericContextWrappingVariable,
    663         # If there is, we roll back to the checkpoint and fall back.

File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:2341, in InstructionTranslatorBase.CALL(self, inst)
   2339 @break_graph_if_unsupported(push=1)
   2340 def CALL(self, inst):
-> 2341     self._call(inst)

File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:2335, in InstructionTranslatorBase._call(self, inst, call_kw)
   2330     kwargs = {}
   2332 try:
   2333     # if call_function fails, need to set kw_names to None, otherwise
   2334     # a subsequent call may have self.kw_names set to an old value
-> 2335     self.call_function(fn, args, kwargs)
   2336 finally:
   2337     self.kw_names = None

    [... skipping similar frames: InstructionTranslatorBase.call_function at line 897 (3 times), InstructionTranslatorBase.CALL at line 2341 (2 times), InstructionTranslatorBase._call at line 2335 (2 times), UnspecializedNNModuleVariable.call_function at line 914 (2 times), UserFunctionVariable.call_function at line 317 (2 times), BaseUserFunctionVariable.call_function at line 118 (2 times), InliningInstructionTranslator.inline_call at line 3072 (2 times), InliningInstructionTranslator.inline_call_ at line 3198 (2 times), InstructionTranslatorBase.inline_user_function_return at line 903 (2 times), InstructionTranslatorBase.run at line 1052 (2 times), InstructionTranslatorBase.step at line 962 (2 times), break_graph_if_unsupported.<locals>.decorator.<locals>.wrapper at line 659 (2 times), _create_realize_and_forward.<locals>.realize_and_forward at line 170 (1 times)]

File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/variables/lazy.py:170, in _create_realize_and_forward.<locals>.realize_and_forward(self, *args, **kwargs)
    166 @functools.wraps(getattr(VariableTracker, name))
    167 def realize_and_forward(
    168     self: LazyVariableTracker, *args: Any, **kwargs: Any
    169 ) -> Any:
--> 170     return getattr(self.realize(), name)(*args, **kwargs)

File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/variables/nn_module.py:914, in UnspecializedNNModuleVariable.call_function(self, tx, args, kwargs)
    906 ctx = (
    907     record_nn_module_stack(
    908         str(id(mod)), self.get_nn_module_stack_source(), tx, mod
   (...)    911     else nullcontext()
    912 )
    913 with ctx:
--> 914     return variables.UserFunctionVariable(fn, source=source).call_function(
    915         tx, [self] + list(args), kwargs
    916     )

File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/variables/functions.py:317, in UserFunctionVariable.call_function(self, tx, args, kwargs)
    315         with torch._dynamo.side_effects.allow_side_effects_under_checkpoint(tx):
    316             return super().call_function(tx, args, kwargs)
--> 317 return super().call_function(tx, args, kwargs)

File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/variables/functions.py:118, in BaseUserFunctionVariable.call_function(self, tx, args, kwargs)
    112 def call_function(
    113     self,
    114     tx: "InstructionTranslator",
    115     args: "List[VariableTracker]",
    116     kwargs: "Dict[str, VariableTracker]",
    117 ) -> "VariableTracker":
--> 118     return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)

File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:903, in InstructionTranslatorBase.inline_user_function_return(self, fn, args, kwargs)
    899 def inline_user_function_return(self, fn, args, kwargs):
    900     """
    901     A call to some user defined function by inlining it.
    902     """
--> 903     return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)

File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:3072, in InliningInstructionTranslator.inline_call(cls, parent, func, args, kwargs)
   3069 @classmethod
   3070 def inline_call(cls, parent, func, args, kwargs):
   3071     with patch.dict(counters, {"unimplemented": counters["inline_call"]}):
-> 3072         return cls.inline_call_(parent, func, args, kwargs)

File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:3198, in InliningInstructionTranslator.inline_call_(parent, func, args, kwargs)
   3196 try:
   3197     with strict_ctx:
-> 3198         tracer.run()
   3199 except exc.ObservedException as e:
   3200     msg = f"Observed exception DURING INLING {code} : {e}"

File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:1052, in InstructionTranslatorBase.run(self)
   1050 try:
   1051     self.output.push_tx(self)
-> 1052     while self.step():
   1053         pass
   1054 except TensorifyScalarRestartAnalysis:

File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:962, in InstructionTranslatorBase.step(self)
    959 self.update_block_stack(inst)
    961 try:
--> 962     self.dispatch_table[inst.opcode](self, inst)
    963     return not self.output.should_exit
    964 except TensorifyScalarRestartAnalysis:

File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:659, in break_graph_if_unsupported.<locals>.decorator.<locals>.wrapper(self, inst)
    657     return handle_graph_break(self, inst, speculation.reason)
    658 try:
--> 659     return inner_fn(self, inst)
    660 except Unsupported as excp:
    661     if self.generic_context_manager_depth > 0:
    662         # We don't support graph break under GenericContextWrappingVariable,
    663         # If there is, we roll back to the checkpoint and fall back.

File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:2341, in InstructionTranslatorBase.CALL(self, inst)
   2339 @break_graph_if_unsupported(push=1)
   2340 def CALL(self, inst):
-> 2341     self._call(inst)

File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:2335, in InstructionTranslatorBase._call(self, inst, call_kw)
   2330     kwargs = {}
   2332 try:
   2333     # if call_function fails, need to set kw_names to None, otherwise
   2334     # a subsequent call may have self.kw_names set to an old value
-> 2335     self.call_function(fn, args, kwargs)
   2336 finally:
   2337     self.kw_names = None

File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:897, in InstructionTranslatorBase.call_function(self, fn, args, kwargs)
    895 if inner_fn and callable(inner_fn) and is_forbidden(inner_fn):
    896     raise AssertionError(f"Attempt to trace forbidden callable {inner_fn}")
--> 897 self.push(fn.call_function(self, args, kwargs))

File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/variables/torch.py:953, in TorchInGraphFunctionVariable.call_function(self, tx, args, kwargs)
    944         if "out" in kwargs and isinstance(kwargs["out"], variables.TensorVariable):
    945             # Calling fake tensor propagation can mutate the out= tensor in
    946             # tx.output.tracked_fakes. tracked_fakes are used to apply
   (...)    949             # guards. So save the shape now, and check later if it has
    950             # changed. If it has, graph break.
    951             fake_out_shape = kwargs["out"].proxy.node.meta["example_value"].shape
--> 953         tensor_variable = wrap_fx_proxy(
    954             tx=tx,
    955             proxy=tx.output.create_proxy(
    956                 "call_function",
    957                 fn_,
    958                 *proxy_args_kwargs(args, kwargs),
    959             ),
    960         )
    962         if (
    963             isinstance(tensor_variable, TensorVariable)
    964             and "requires_grad" in kwargs
    965             and kwargs["requires_grad"].as_python_constant()
    966         ):
    967             unimplemented(
    968                 """factory functions that return tensors that require grad are not supported.
    969 Either create the tensor outside the compiled region, or do not set the tensor to require_grad"""
    970             )

File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/variables/builder.py:2153, in wrap_fx_proxy(tx, proxy, example_value, subclass_type, **options)
   2145 kwargs = {
   2146     "tx": tx,
   2147     "proxy": proxy,
   (...)   2150     **options,
   2151 }
   2152 if subclass_type is None:
-> 2153     return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs)
   2154 else:
   2155     result = wrap_fx_proxy_cls(target_cls=TensorWithTFOverrideVariable, **kwargs)

File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/variables/builder.py:2219, in wrap_fx_proxy_cls(target_cls, tx, proxy, example_value, subclass_type, **options)
   2215 def wrap_fx_proxy_cls(
   2216     target_cls, tx, proxy, example_value=None, subclass_type=None, **options
   2217 ):
   2218     if example_value is None:
-> 2219         return _wrap_fx_proxy(
   2220             target_cls, tx, proxy, example_value, subclass_type, **options
   2221         )
   2222     elif isinstance(example_value, torch.Tensor):
   2223         return _wrap_fx_preexisting_tensor(
   2224             target_cls, tx, proxy, example_value, subclass_type, **options
   2225         )

File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/variables/builder.py:2315, in _wrap_fx_proxy(target_cls, tx, proxy, example_value, subclass_type, **options)
   2310 # See NOTE: [Deferring tensor pack/unpack hooks until runtime]
   2311 with torch._dynamo.utils._disable_saved_tensors_hooks_during_tracing():
   2312     # with preserve_rng_state():
   2313     # only allow_non_graph_fake in this instance because we handle the non-fake
   2314     # cases properly below.
-> 2315     example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True)
   2317 return handle_traced_output(
   2318     example_value, tx, proxy, options, subclass_type, target_cls
   2319 )

File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/utils.py:2536, in get_fake_value(node, tx, allow_non_graph_fake)
   2533     elif isinstance(cause, TypeError) and "argument" in str(cause):
   2534         unimplemented(f"TypeError {node.target}: {cause}")
-> 2536     raise TorchRuntimeError(str(e)).with_traceback(e.__traceback__) from None
   2538 if not allow_non_graph_fake:
   2539     _ = pytree.tree_map_only(
   2540         torch.Tensor, functools.partial(ensure_graph_fake, tx=tx), ret_val
   2541     )

File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/utils.py:2471, in get_fake_value(node, tx, allow_non_graph_fake)
   2469 try:
   2470     with tx.fake_mode, enable_python_dispatcher():
-> 2471         ret_val = wrap_fake_exception(
   2472             lambda: run_node(tx.output, node, args, kwargs, nnmodule)
   2473         )
   2474 except Unsupported:
   2475     raise

File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/utils.py:2017, in wrap_fake_exception(fn)
   2015 def wrap_fake_exception(fn):
   2016     try:
-> 2017         return fn()
   2018     except UnsupportedFakeTensorException as e:
   2019         from .exc import unimplemented

File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/utils.py:2472, in get_fake_value.<locals>.<lambda>()
   2469 try:
   2470     with tx.fake_mode, enable_python_dispatcher():
   2471         ret_val = wrap_fake_exception(
-> 2472             lambda: run_node(tx.output, node, args, kwargs, nnmodule)
   2473         )
   2474 except Unsupported:
   2475     raise

File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/utils.py:2604, in run_node(tracer, node, args, kwargs, nnmodule)
   2602         unimplemented(make_error_message(e), from_exc=e)
   2603     except Exception as e:
-> 2604         raise RuntimeError(make_error_message(e)).with_traceback(
   2605             e.__traceback__
   2606         ) from e
   2608 raise AssertionError(op)

File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/utils.py:2586, in run_node(tracer, node, args, kwargs, nnmodule)
   2584 try:
   2585     if op == "call_function":
-> 2586         return node.target(*args, **kwargs)
   2587     elif op == "call_method":
   2588         return getattr(args[0], node.target)(*args[1:], **kwargs)

File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/utils/_stats.py:21, in count.<locals>.wrapper(*args, **kwargs)
     19     simple_call_counter[fn.__qualname__] = 0
     20 simple_call_counter[fn.__qualname__] = simple_call_counter[fn.__qualname__] + 1
---> 21 return fn(*args, **kwargs)

File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_subclasses/fake_tensor.py:1276, in FakeTensorMode.__torch_dispatch__(self, func, types, args, kwargs)
   1272 assert (
   1273     torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.FAKE) is None
   1274 ), func
   1275 try:
-> 1276     return self.dispatch(func, types, args, kwargs)
   1277 except TypeError:
   1278     log.exception("fake tensor raised TypeError")

File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_subclasses/fake_tensor.py:1816, in FakeTensorMode.dispatch(self, func, types, args, kwargs)
   1813         return func(*args, **kwargs)
   1815 if self.cache_enabled:
-> 1816     return self._cached_dispatch_impl(func, types, args, kwargs)
   1817 else:
   1818     return self._dispatch_impl(func, types, args, kwargs)

File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_subclasses/fake_tensor.py:1377, in FakeTensorMode._cached_dispatch_impl(self, func, types, args, kwargs)
   1375 else:
   1376     self._validate_cache_key(func, args, kwargs)
-> 1377     output = self._dispatch_impl(func, types, args, kwargs)
   1378     entry = self._make_cache_entry(state, key, func, args, kwargs, output)
   1379     key.strip_shape_env()

File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_subclasses/fake_tensor.py:2392, in FakeTensorMode._dispatch_impl(self, func, types, args, kwargs)
   2388     log.exception("failed while attempting to run meta for %s", func)
   2389     raise
   2391 return maybe_propagate_real_tensors(
-> 2392     self.wrap_meta_outputs_with_default_device_logic(
   2393         r, func, flat_args, device=kwargs.get("device")
   2394     )
   2395 )

File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_subclasses/fake_tensor.py:2514, in FakeTensorMode.wrap_meta_outputs_with_default_device_logic(self, r, func, flat_args, device)
   2511     else:
   2512         return e
-> 2514 return tree_map(wrap, r)

File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/utils/_pytree.py:991, in tree_map(func, tree, is_leaf, *rests)
    989 leaves, treespec = tree_flatten(tree, is_leaf=is_leaf)
    990 flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests]
--> 991 return treespec.unflatten(map(func, *flat_args))

File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/utils/_pytree.py:830, in TreeSpec.unflatten(self, leaves)
    828 def unflatten(self, leaves: Iterable[Any]) -> PyTree:
    829     if not isinstance(leaves, (list, tuple)):
--> 830         leaves = list(leaves)
    831     if len(leaves) != self.num_leaves:
    832         raise ValueError(
    833             f"treespec.unflatten(leaves): `leaves` has length {len(leaves)} "
    834             f"but the spec refers to a pytree that holds {self.num_leaves} "
    835             f"items ({self}).",
    836         )

File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_subclasses/fake_tensor.py:2492, in FakeTensorMode.wrap_meta_outputs_with_default_device_logic.<locals>.wrap(e)
   2486     return e
   2488 if common_device is None:
   2489     (
   2490         common_device,
   2491         has_scalar_only_inputs,
-> 2492     ) = FakeTensor._find_common_device(func, flat_args)
   2494 is_our_fake = self.is_our_fake(e)
   2495 if is_our_fake:

File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_subclasses/fake_tensor.py:890, in FakeTensor._find_common_device(func, flat_args)
    885     raise RuntimeError(
    886         f"Unhandled FakeTensor Device Propagation for {func}, found two different devices {common_device}, {t.device}"
    887     )
    889 for arg in flat_args:
--> 890     merge_devices(arg)
    892 # some functions that allow Python numbers to bind to Tensors
    893 # if we have failed to find a device, and we're running one of these operators,
    894 # we must have scalar only inputs
    895 if should_allow_numbers_as_tensors(func) and common_device is None:
    896     # ops with scalar only inputs always have result on cpu

File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_subclasses/fake_tensor.py:885, in FakeTensor._find_common_device.<locals>.merge_devices(t)
    881     return
    883 # mismatching devices of non-zero dim tensors, throw
    884 # This might be valid behavior and need to be explicitly modeled, e.g. reshape_as
--> 885 raise RuntimeError(
    886     f"Unhandled FakeTensor Device Propagation for {func}, found two different devices {common_device}, {t.device}"
    887 )

TorchRuntimeError: Failed running call_function <built-in function scaled_dot_product_attention>(*(FakeTensor(..., device='cuda:0', size=(2, 16, 3, 64),
           grad_fn=<PermuteBackward0>), FakeTensor(..., device='cuda:0', size=(2, 16, 3, 64),
           grad_fn=<PermuteBackward0>), FakeTensor(..., device='cuda:0', size=(2, 16, 3, 64),
           grad_fn=<PermuteBackward0>)), **{'attn_mask': FakeTensor(..., size=(2, 1, 3, 3)), 'dropout_p': 0.0, 'is_causal': False}):
Unhandled FakeTensor Device Propagation for aten._scaled_dot_product_efficient_attention.default, found two different devices cuda:0, cpu
You need to confirm your account before you can post a new comment.

Sign up or log in to comment