Can't compile the model?
#26 opened by Luisds95
Hi! Thanks a lot for this model, it seems promising!
I tried to compile it so it runs and trains faster, but I'm getting a device mismatch between tensors at one of the attention layers. Any clues on how to solve this?
Here's a minimal example:
import torch
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")
model = torch.compile(model)

tokens = model.tokenize(["Hello", "World"])
model(tokens)
Which raises the error: Unhandled FakeTensor Device Propagation for aten._scaled_dot_product_efficient_attention.default, found two different devices cuda:0, cpu
I'm using sentence-transformers==3.4.1 and torch==2.3.1.
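For what it's worth, in the trace below the attn_mask FakeTensor has no cuda device while the query/key/value tensors are on cuda:0, so my guess is that the tokenized batch stays on CPU when the module is called directly (encode() normally moves it to the model's device). Here's a sketch of what I mean, assuming sentence_transformers.util.batch_to_device behaves the same way encode() uses it internally; I haven't confirmed whether this actually avoids the compile error:

import torch
from sentence_transformers import SentenceTransformer
from sentence_transformers.util import batch_to_device

model = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")
model = torch.compile(model)

# tokenize() returns CPU tensors; explicitly move the whole batch
# (input_ids, attention_mask, ...) to the model's device before the
# forward pass, mirroring what encode() does internally.
tokens = model.tokenize(["Hello", "World"])
tokens = batch_to_device(tokens, model.device)

model(tokens)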
Here's the full trace in case it helps:
TorchRuntimeError Traceback (most recent call last)
Cell In[2], line 7
4 model = torch.compile(model)
6 tokens = model.tokenize(["Hello", "World"])
----> 7 model(tokens)
File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/nn/modules/module.py:1739, in Module._wrapped_call_impl(self, *args, **kwargs)
1737 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1738 else:
-> 1739 return self._call_impl(*args, **kwargs)
File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/nn/modules/module.py:1750, in Module._call_impl(self, *args, **kwargs)
1745 # If we don't have any hooks, we want to skip the rest of the logic in
1746 # this function, and just call forward.
1747 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1748 or _global_backward_pre_hooks or _global_backward_hooks
1749 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1750 return forward_call(*args, **kwargs)
1752 result = None
1753 called_always_called_hooks = set()
File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py:574, in _TorchDynamoContext.__call__.<locals>._fn(*args, **kwargs)
569 saved_dynamic_layer_stack_depth = (
570 torch._C._functorch.get_dynamic_layer_stack_depth()
571 )
573 try:
--> 574 return fn(*args, **kwargs)
575 finally:
576 # Restore the dynamic layer stack depth if necessary.
577 torch._C._functorch.pop_dynamic_layer_stack_and_undo_to_depth(
578 saved_dynamic_layer_stack_depth
579 )
File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/nn/modules/module.py:1739, in Module._wrapped_call_impl(self, *args, **kwargs)
1737 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1738 else:
-> 1739 return self._call_impl(*args, **kwargs)
File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/nn/modules/module.py:1750, in Module._call_impl(self, *args, **kwargs)
1745 # If we don't have any hooks, we want to skip the rest of the logic in
1746 # this function, and just call forward.
1747 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1748 or _global_backward_pre_hooks or _global_backward_hooks
1749 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1750 return forward_call(*args, **kwargs)
1752 result = None
1753 called_always_called_hooks = set()
File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py:1380, in CatchErrorsWrapper.__call__(self, frame, cache_entry, frame_state)
1374 return hijacked_callback(
1375 frame, cache_entry, self.hooks, frame_state
1376 )
1378 with compile_lock, _disable_current_modes():
1379 # skip=1: skip this frame
-> 1380 return self._torchdynamo_orig_callable(
1381 frame, cache_entry, self.hooks, frame_state, skip=1
1382 )
File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py:1164, in ConvertFrame.__call__(self, frame, cache_entry, hooks, frame_state, skip)
1162 counters["frames"]["total"] += 1
1163 try:
-> 1164 result = self._inner_convert(
1165 frame, cache_entry, hooks, frame_state, skip=skip + 1
1166 )
1167 counters["frames"]["ok"] += 1
1168 return result
File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py:547, in ConvertFrameAssert.__call__(self, frame, cache_entry, hooks, frame_state, skip)
544 dynamo_tls.traced_frame_infos.append(info)
546 with compile_context(CompileContext(compile_id)):
--> 547 return _compile(
548 frame.f_code,
549 frame.f_globals,
550 frame.f_locals,
551 frame.f_builtins,
552 frame.closure,
553 self._torchdynamo_orig_callable,
554 self._one_graph,
555 self._export,
556 self._export_constraints,
557 hooks,
558 cache_entry,
559 cache_size,
560 frame,
561 frame_state=frame_state,
562 compile_id=compile_id,
563 skip=skip + 1,
564 )
File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py:986, in _compile(code, globals, locals, builtins, closure, compiler_fn, one_graph, export, export_constraints, hooks, cache_entry, cache_size, frame, frame_state, compile_id, skip)
984 guarded_code = None
985 try:
--> 986 guarded_code = compile_inner(code, one_graph, hooks, transform)
988 # NB: We only put_code_state in success case. Success case here
989 # does include graph breaks; specifically, if a graph break still
990 # resulted in a partially compiled graph, we WILL return here. An
(...) 995 # to upload for graph break though, because this can prevent
996 # extra graph break compilations.)
997 put_code_state()
File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py:715, in _compile.<locals>.compile_inner(code, one_graph, hooks, transform)
713 stack.enter_context(torch._dynamo.callback_handler.install_callbacks())
714 stack.enter_context(CompileTimeInstructionCounter.record())
--> 715 return _compile_inner(code, one_graph, hooks, transform)
717 return None
File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_utils_internal.py:95, in compile_time_strobelight_meta.<locals>.compile_time_strobelight_meta_inner.<locals>.wrapper_function(*args, **kwargs)
92 kwargs["skip"] = skip + 1
94 if not StrobelightCompileTimeProfiler.enabled:
---> 95 return function(*args, **kwargs)
97 return StrobelightCompileTimeProfiler.profile_compile_time(
98 function, phase_name, *args, **kwargs
99 )
File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py:750, in _compile.<locals>._compile_inner(code, one_graph, hooks, transform)
748 CompileContext.get().attempt = attempt
749 try:
--> 750 out_code = transform_code_object(code, transform)
751 break
752 except exc.RestartAnalysis as e:
File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/bytecode_transformation.py:1361, in transform_code_object(code, transformations, safe)
1358 instructions = cleaned_instructions(code, safe)
1359 propagate_line_nums(instructions)
-> 1361 transformations(instructions, code_options)
1362 return clean_and_assemble_instructions(instructions, keys, code_options)[1]
File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py:231, in preserve_global_state.<locals>._fn(*args, **kwargs)
229 exit_stack.enter_context(torch_function_mode_stack_state_mgr)
230 try:
--> 231 return fn(*args, **kwargs)
232 finally:
233 cleanup.close()
File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/convert_frame.py:662, in _compile.<locals>.transform(instructions, code_options)
660 try:
661 with tracing(tracer.output.tracing_context), tracer.set_current_tx():
--> 662 tracer.run()
663 except exc.UnspecializeRestartAnalysis:
664 speculation_log.clear()
File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:2868, in InstructionTranslator.run(self)
2867 def run(self):
-> 2868 super().run()
File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:1052, in InstructionTranslatorBase.run(self)
1050 try:
1051 self.output.push_tx(self)
-> 1052 while self.step():
1053 pass
1054 except TensorifyScalarRestartAnalysis:
File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:962, in InstructionTranslatorBase.step(self)
959 self.update_block_stack(inst)
961 try:
--> 962 self.dispatch_table[inst.opcode](self, inst)
963 return not self.output.should_exit
964 except TensorifyScalarRestartAnalysis:
File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:659, in break_graph_if_unsupported.<locals>.decorator.<locals>.wrapper(self, inst)
657 return handle_graph_break(self, inst, speculation.reason)
658 try:
--> 659 return inner_fn(self, inst)
660 except Unsupported as excp:
661 if self.generic_context_manager_depth > 0:
662 # We don't support graph break under GenericContextWrappingVariable,
663 # If there is, we roll back to the checkpoint and fall back.
File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:1736, in InstructionTranslatorBase.CALL_FUNCTION_EX(self, inst)
1734 # Map to a dictionary of str -> VariableTracker
1735 kwargsvars = kwargsvars.keys_as_python_constant()
-> 1736 self.call_function(fn, argsvars.items, kwargsvars)
File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:897, in InstructionTranslatorBase.call_function(self, fn, args, kwargs)
895 if inner_fn and callable(inner_fn) and is_forbidden(inner_fn):
896 raise AssertionError(f"Attempt to trace forbidden callable {inner_fn}")
--> 897 self.push(fn.call_function(self, args, kwargs))
File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/variables/nn_module.py:914, in UnspecializedNNModuleVariable.call_function(self, tx, args, kwargs)
906 ctx = (
907 record_nn_module_stack(
908 str(id(mod)), self.get_nn_module_stack_source(), tx, mod
(...) 911 else nullcontext()
912 )
913 with ctx:
--> 914 return variables.UserFunctionVariable(fn, source=source).call_function(
915 tx, [self] + list(args), kwargs
916 )
File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/variables/functions.py:317, in UserFunctionVariable.call_function(self, tx, args, kwargs)
315 with torch._dynamo.side_effects.allow_side_effects_under_checkpoint(tx):
316 return super().call_function(tx, args, kwargs)
--> 317 return super().call_function(tx, args, kwargs)
File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/variables/functions.py:118, in BaseUserFunctionVariable.call_function(self, tx, args, kwargs)
112 def call_function(
113 self,
114 tx: "InstructionTranslator",
115 args: "List[VariableTracker]",
116 kwargs: "Dict[str, VariableTracker]",
117 ) -> "VariableTracker":
--> 118 return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:903, in InstructionTranslatorBase.inline_user_function_return(self, fn, args, kwargs)
899 def inline_user_function_return(self, fn, args, kwargs):
900 """
901 A call to some user defined function by inlining it.
902 """
--> 903 return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:3072, in InliningInstructionTranslator.inline_call(cls, parent, func, args, kwargs)
3069 @classmethod
3070 def inline_call(cls, parent, func, args, kwargs):
3071 with patch.dict(counters, {"unimplemented": counters["inline_call"]}):
-> 3072 return cls.inline_call_(parent, func, args, kwargs)
File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:3198, in InliningInstructionTranslator.inline_call_(parent, func, args, kwargs)
3196 try:
3197 with strict_ctx:
-> 3198 tracer.run()
3199 except exc.ObservedException as e:
3200 msg = f"Observed exception DURING INLING {code} : {e}"
File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:1052, in InstructionTranslatorBase.run(self)
1050 try:
1051 self.output.push_tx(self)
-> 1052 while self.step():
1053 pass
1054 except TensorifyScalarRestartAnalysis:
File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:962, in InstructionTranslatorBase.step(self)
959 self.update_block_stack(inst)
961 try:
--> 962 self.dispatch_table[inst.opcode](self, inst)
963 return not self.output.should_exit
964 except TensorifyScalarRestartAnalysis:
File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:659, in break_graph_if_unsupported.<locals>.decorator.<locals>.wrapper(self, inst)
657 return handle_graph_break(self, inst, speculation.reason)
658 try:
--> 659 return inner_fn(self, inst)
660 except Unsupported as excp:
661 if self.generic_context_manager_depth > 0:
662 # We don't support graph break under GenericContextWrappingVariable,
663 # If there is, we roll back to the checkpoint and fall back.
File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:1736, in InstructionTranslatorBase.CALL_FUNCTION_EX(self, inst)
1734 # Map to a dictionary of str -> VariableTracker
1735 kwargsvars = kwargsvars.keys_as_python_constant()
-> 1736 self.call_function(fn, argsvars.items, kwargsvars)
File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:897, in InstructionTranslatorBase.call_function(self, fn, args, kwargs)
895 if inner_fn and callable(inner_fn) and is_forbidden(inner_fn):
896 raise AssertionError(f"Attempt to trace forbidden callable {inner_fn}")
--> 897 self.push(fn.call_function(self, args, kwargs))
File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/variables/lazy.py:170, in _create_realize_and_forward.<locals>.realize_and_forward(self, *args, **kwargs)
166 @functools.wraps(getattr(VariableTracker, name))
167 def realize_and_forward(
168 self: LazyVariableTracker, *args: Any, **kwargs: Any
169 ) -> Any:
--> 170 return getattr(self.realize(), name)(*args, **kwargs)
File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/variables/nn_module.py:914, in UnspecializedNNModuleVariable.call_function(self, tx, args, kwargs)
906 ctx = (
907 record_nn_module_stack(
908 str(id(mod)), self.get_nn_module_stack_source(), tx, mod
(...) 911 else nullcontext()
912 )
913 with ctx:
--> 914 return variables.UserFunctionVariable(fn, source=source).call_function(
915 tx, [self] + list(args), kwargs
916 )
File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/variables/functions.py:317, in UserFunctionVariable.call_function(self, tx, args, kwargs)
315 with torch._dynamo.side_effects.allow_side_effects_under_checkpoint(tx):
316 return super().call_function(tx, args, kwargs)
--> 317 return super().call_function(tx, args, kwargs)
File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/variables/functions.py:118, in BaseUserFunctionVariable.call_function(self, tx, args, kwargs)
112 def call_function(
113 self,
114 tx: "InstructionTranslator",
115 args: "List[VariableTracker]",
116 kwargs: "Dict[str, VariableTracker]",
117 ) -> "VariableTracker":
--> 118 return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:903, in InstructionTranslatorBase.inline_user_function_return(self, fn, args, kwargs)
899 def inline_user_function_return(self, fn, args, kwargs):
900 """
901 A call to some user defined function by inlining it.
902 """
--> 903 return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:3072, in InliningInstructionTranslator.inline_call(cls, parent, func, args, kwargs)
3069 @classmethod
3070 def inline_call(cls, parent, func, args, kwargs):
3071 with patch.dict(counters, {"unimplemented": counters["inline_call"]}):
-> 3072 return cls.inline_call_(parent, func, args, kwargs)
File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:3198, in InliningInstructionTranslator.inline_call_(parent, func, args, kwargs)
3196 try:
3197 with strict_ctx:
-> 3198 tracer.run()
3199 except exc.ObservedException as e:
3200 msg = f"Observed exception DURING INLING {code} : {e}"
File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:1052, in InstructionTranslatorBase.run(self)
1050 try:
1051 self.output.push_tx(self)
-> 1052 while self.step():
1053 pass
1054 except TensorifyScalarRestartAnalysis:
File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:962, in InstructionTranslatorBase.step(self)
959 self.update_block_stack(inst)
961 try:
--> 962 self.dispatch_table[inst.opcode](self, inst)
963 return not self.output.should_exit
964 except TensorifyScalarRestartAnalysis:
File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:659, in break_graph_if_unsupported.<locals>.decorator.<locals>.wrapper(self, inst)
657 return handle_graph_break(self, inst, speculation.reason)
658 try:
--> 659 return inner_fn(self, inst)
660 except Unsupported as excp:
661 if self.generic_context_manager_depth > 0:
662 # We don't support graph break under GenericContextWrappingVariable,
663 # If there is, we roll back to the checkpoint and fall back.
File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:2341, in InstructionTranslatorBase.CALL(self, inst)
2339 @break_graph_if_unsupported(push=1)
2340 def CALL(self, inst):
-> 2341 self._call(inst)
File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:2335, in InstructionTranslatorBase._call(self, inst, call_kw)
2330 kwargs = {}
2332 try:
2333 # if call_function fails, need to set kw_names to None, otherwise
2334 # a subsequent call may have self.kw_names set to an old value
-> 2335 self.call_function(fn, args, kwargs)
2336 finally:
2337 self.kw_names = None
File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:897, in InstructionTranslatorBase.call_function(self, fn, args, kwargs)
895 if inner_fn and callable(inner_fn) and is_forbidden(inner_fn):
896 raise AssertionError(f"Attempt to trace forbidden callable {inner_fn}")
--> 897 self.push(fn.call_function(self, args, kwargs))
File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/variables/lazy.py:170, in _create_realize_and_forward.<locals>.realize_and_forward(self, *args, **kwargs)
166 @functools.wraps(getattr(VariableTracker, name))
167 def realize_and_forward(
168 self: LazyVariableTracker, *args: Any, **kwargs: Any
169 ) -> Any:
--> 170 return getattr(self.realize(), name)(*args, **kwargs)
File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/variables/nn_module.py:914, in UnspecializedNNModuleVariable.call_function(self, tx, args, kwargs)
906 ctx = (
907 record_nn_module_stack(
908 str(id(mod)), self.get_nn_module_stack_source(), tx, mod
(...) 911 else nullcontext()
912 )
913 with ctx:
--> 914 return variables.UserFunctionVariable(fn, source=source).call_function(
915 tx, [self] + list(args), kwargs
916 )
File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/variables/functions.py:317, in UserFunctionVariable.call_function(self, tx, args, kwargs)
315 with torch._dynamo.side_effects.allow_side_effects_under_checkpoint(tx):
316 return super().call_function(tx, args, kwargs)
--> 317 return super().call_function(tx, args, kwargs)
File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/variables/functions.py:118, in BaseUserFunctionVariable.call_function(self, tx, args, kwargs)
112 def call_function(
113 self,
114 tx: "InstructionTranslator",
115 args: "List[VariableTracker]",
116 kwargs: "Dict[str, VariableTracker]",
117 ) -> "VariableTracker":
--> 118 return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:903, in InstructionTranslatorBase.inline_user_function_return(self, fn, args, kwargs)
899 def inline_user_function_return(self, fn, args, kwargs):
900 """
901 A call to some user defined function by inlining it.
902 """
--> 903 return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:3072, in InliningInstructionTranslator.inline_call(cls, parent, func, args, kwargs)
3069 @classmethod
3070 def inline_call(cls, parent, func, args, kwargs):
3071 with patch.dict(counters, {"unimplemented": counters["inline_call"]}):
-> 3072 return cls.inline_call_(parent, func, args, kwargs)
File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:3198, in InliningInstructionTranslator.inline_call_(parent, func, args, kwargs)
3196 try:
3197 with strict_ctx:
-> 3198 tracer.run()
3199 except exc.ObservedException as e:
3200 msg = f"Observed exception DURING INLING {code} : {e}"
File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:1052, in InstructionTranslatorBase.run(self)
1050 try:
1051 self.output.push_tx(self)
-> 1052 while self.step():
1053 pass
1054 except TensorifyScalarRestartAnalysis:
File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:962, in InstructionTranslatorBase.step(self)
959 self.update_block_stack(inst)
961 try:
--> 962 self.dispatch_table[inst.opcode](self, inst)
963 return not self.output.should_exit
964 except TensorifyScalarRestartAnalysis:
File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:659, in break_graph_if_unsupported.<locals>.decorator.<locals>.wrapper(self, inst)
657 return handle_graph_break(self, inst, speculation.reason)
658 try:
--> 659 return inner_fn(self, inst)
660 except Unsupported as excp:
661 if self.generic_context_manager_depth > 0:
662 # We don't support graph break under GenericContextWrappingVariable,
663 # If there is, we roll back to the checkpoint and fall back.
File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:2341, in InstructionTranslatorBase.CALL(self, inst)
2339 @break_graph_if_unsupported(push=1)
2340 def CALL(self, inst):
-> 2341 self._call(inst)
File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:2335, in InstructionTranslatorBase._call(self, inst, call_kw)
2330 kwargs = {}
2332 try:
2333 # if call_function fails, need to set kw_names to None, otherwise
2334 # a subsequent call may have self.kw_names set to an old value
-> 2335 self.call_function(fn, args, kwargs)
2336 finally:
2337 self.kw_names = None
[... skipping similar frames: InstructionTranslatorBase.call_function at line 897 (3 times), InstructionTranslatorBase.CALL at line 2341 (2 times), InstructionTranslatorBase._call at line 2335 (2 times), UnspecializedNNModuleVariable.call_function at line 914 (2 times), UserFunctionVariable.call_function at line 317 (2 times), BaseUserFunctionVariable.call_function at line 118 (2 times), InliningInstructionTranslator.inline_call at line 3072 (2 times), InliningInstructionTranslator.inline_call_ at line 3198 (2 times), InstructionTranslatorBase.inline_user_function_return at line 903 (2 times), InstructionTranslatorBase.run at line 1052 (2 times), InstructionTranslatorBase.step at line 962 (2 times), break_graph_if_unsupported.<locals>.decorator.<locals>.wrapper at line 659 (2 times), _create_realize_and_forward.<locals>.realize_and_forward at line 170 (1 times)]
File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/variables/lazy.py:170, in _create_realize_and_forward.<locals>.realize_and_forward(self, *args, **kwargs)
166 @functools.wraps(getattr(VariableTracker, name))
167 def realize_and_forward(
168 self: LazyVariableTracker, *args: Any, **kwargs: Any
169 ) -> Any:
--> 170 return getattr(self.realize(), name)(*args, **kwargs)
File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/variables/nn_module.py:914, in UnspecializedNNModuleVariable.call_function(self, tx, args, kwargs)
906 ctx = (
907 record_nn_module_stack(
908 str(id(mod)), self.get_nn_module_stack_source(), tx, mod
(...) 911 else nullcontext()
912 )
913 with ctx:
--> 914 return variables.UserFunctionVariable(fn, source=source).call_function(
915 tx, [self] + list(args), kwargs
916 )
File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/variables/functions.py:317, in UserFunctionVariable.call_function(self, tx, args, kwargs)
315 with torch._dynamo.side_effects.allow_side_effects_under_checkpoint(tx):
316 return super().call_function(tx, args, kwargs)
--> 317 return super().call_function(tx, args, kwargs)
File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/variables/functions.py:118, in BaseUserFunctionVariable.call_function(self, tx, args, kwargs)
112 def call_function(
113 self,
114 tx: "InstructionTranslator",
115 args: "List[VariableTracker]",
116 kwargs: "Dict[str, VariableTracker]",
117 ) -> "VariableTracker":
--> 118 return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:903, in InstructionTranslatorBase.inline_user_function_return(self, fn, args, kwargs)
899 def inline_user_function_return(self, fn, args, kwargs):
900 """
901 A call to some user defined function by inlining it.
902 """
--> 903 return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:3072, in InliningInstructionTranslator.inline_call(cls, parent, func, args, kwargs)
3069 @classmethod
3070 def inline_call(cls, parent, func, args, kwargs):
3071 with patch.dict(counters, {"unimplemented": counters["inline_call"]}):
-> 3072 return cls.inline_call_(parent, func, args, kwargs)
File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:3198, in InliningInstructionTranslator.inline_call_(parent, func, args, kwargs)
3196 try:
3197 with strict_ctx:
-> 3198 tracer.run()
3199 except exc.ObservedException as e:
3200 msg = f"Observed exception DURING INLING {code} : {e}"
File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:1052, in InstructionTranslatorBase.run(self)
1050 try:
1051 self.output.push_tx(self)
-> 1052 while self.step():
1053 pass
1054 except TensorifyScalarRestartAnalysis:
File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:962, in InstructionTranslatorBase.step(self)
959 self.update_block_stack(inst)
961 try:
--> 962 self.dispatch_table[inst.opcode](self, inst)
963 return not self.output.should_exit
964 except TensorifyScalarRestartAnalysis:
File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:659, in break_graph_if_unsupported.<locals>.decorator.<locals>.wrapper(self, inst)
657 return handle_graph_break(self, inst, speculation.reason)
658 try:
--> 659 return inner_fn(self, inst)
660 except Unsupported as excp:
661 if self.generic_context_manager_depth > 0:
662 # We don't support graph break under GenericContextWrappingVariable,
663 # If there is, we roll back to the checkpoint and fall back.
File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:2341, in InstructionTranslatorBase.CALL(self, inst)
2339 @break_graph_if_unsupported(push=1)
2340 def CALL(self, inst):
-> 2341 self._call(inst)
File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:2335, in InstructionTranslatorBase._call(self, inst, call_kw)
2330 kwargs = {}
2332 try:
2333 # if call_function fails, need to set kw_names to None, otherwise
2334 # a subsequent call may have self.kw_names set to an old value
-> 2335 self.call_function(fn, args, kwargs)
2336 finally:
2337 self.kw_names = None
File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/symbolic_convert.py:897, in InstructionTranslatorBase.call_function(self, fn, args, kwargs)
895 if inner_fn and callable(inner_fn) and is_forbidden(inner_fn):
896 raise AssertionError(f"Attempt to trace forbidden callable {inner_fn}")
--> 897 self.push(fn.call_function(self, args, kwargs))
File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/variables/torch.py:953, in TorchInGraphFunctionVariable.call_function(self, tx, args, kwargs)
944 if "out" in kwargs and isinstance(kwargs["out"], variables.TensorVariable):
945 # Calling fake tensor propagation can mutate the out= tensor in
946 # tx.output.tracked_fakes. tracked_fakes are used to apply
(...) 949 # guards. So save the shape now, and check later if it has
950 # changed. If it has, graph break.
951 fake_out_shape = kwargs["out"].proxy.node.meta["example_value"].shape
--> 953 tensor_variable = wrap_fx_proxy(
954 tx=tx,
955 proxy=tx.output.create_proxy(
956 "call_function",
957 fn_,
958 *proxy_args_kwargs(args, kwargs),
959 ),
960 )
962 if (
963 isinstance(tensor_variable, TensorVariable)
964 and "requires_grad" in kwargs
965 and kwargs["requires_grad"].as_python_constant()
966 ):
967 unimplemented(
968 """factory functions that return tensors that require grad are not supported.
969 Either create the tensor outside the compiled region, or do not set the tensor to require_grad"""
970 )
File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/variables/builder.py:2153, in wrap_fx_proxy(tx, proxy, example_value, subclass_type, **options)
2145 kwargs = {
2146 "tx": tx,
2147 "proxy": proxy,
(...) 2150 **options,
2151 }
2152 if subclass_type is None:
-> 2153 return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs)
2154 else:
2155 result = wrap_fx_proxy_cls(target_cls=TensorWithTFOverrideVariable, **kwargs)
File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/variables/builder.py:2219, in wrap_fx_proxy_cls(target_cls, tx, proxy, example_value, subclass_type, **options)
2215 def wrap_fx_proxy_cls(
2216 target_cls, tx, proxy, example_value=None, subclass_type=None, **options
2217 ):
2218 if example_value is None:
-> 2219 return _wrap_fx_proxy(
2220 target_cls, tx, proxy, example_value, subclass_type, **options
2221 )
2222 elif isinstance(example_value, torch.Tensor):
2223 return _wrap_fx_preexisting_tensor(
2224 target_cls, tx, proxy, example_value, subclass_type, **options
2225 )
File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/variables/builder.py:2315, in _wrap_fx_proxy(target_cls, tx, proxy, example_value, subclass_type, **options)
2310 # See NOTE: [Deferring tensor pack/unpack hooks until runtime]
2311 with torch._dynamo.utils._disable_saved_tensors_hooks_during_tracing():
2312 # with preserve_rng_state():
2313 # only allow_non_graph_fake in this instance because we handle the non-fake
2314 # cases properly below.
-> 2315 example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True)
2317 return handle_traced_output(
2318 example_value, tx, proxy, options, subclass_type, target_cls
2319 )
File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/utils.py:2536, in get_fake_value(node, tx, allow_non_graph_fake)
2533 elif isinstance(cause, TypeError) and "argument" in str(cause):
2534 unimplemented(f"TypeError {node.target}: {cause}")
-> 2536 raise TorchRuntimeError(str(e)).with_traceback(e.__traceback__) from None
2538 if not allow_non_graph_fake:
2539 _ = pytree.tree_map_only(
2540 torch.Tensor, functools.partial(ensure_graph_fake, tx=tx), ret_val
2541 )
File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/utils.py:2471, in get_fake_value(node, tx, allow_non_graph_fake)
2469 try:
2470 with tx.fake_mode, enable_python_dispatcher():
-> 2471 ret_val = wrap_fake_exception(
2472 lambda: run_node(tx.output, node, args, kwargs, nnmodule)
2473 )
2474 except Unsupported:
2475 raise
File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/utils.py:2017, in wrap_fake_exception(fn)
2015 def wrap_fake_exception(fn):
2016 try:
-> 2017 return fn()
2018 except UnsupportedFakeTensorException as e:
2019 from .exc import unimplemented
File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/utils.py:2472, in get_fake_value.<locals>.<lambda>()
2469 try:
2470 with tx.fake_mode, enable_python_dispatcher():
2471 ret_val = wrap_fake_exception(
-> 2472 lambda: run_node(tx.output, node, args, kwargs, nnmodule)
2473 )
2474 except Unsupported:
2475 raise
File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/utils.py:2604, in run_node(tracer, node, args, kwargs, nnmodule)
2602 unimplemented(make_error_message(e), from_exc=e)
2603 except Exception as e:
-> 2604 raise RuntimeError(make_error_message(e)).with_traceback(
2605 e.__traceback__
2606 ) from e
2608 raise AssertionError(op)
File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_dynamo/utils.py:2586, in run_node(tracer, node, args, kwargs, nnmodule)
2584 try:
2585 if op == "call_function":
-> 2586 return node.target(*args, **kwargs)
2587 elif op == "call_method":
2588 return getattr(args[0], node.target)(*args[1:], **kwargs)
File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/utils/_stats.py:21, in count.<locals>.wrapper(*args, **kwargs)
19 simple_call_counter[fn.__qualname__] = 0
20 simple_call_counter[fn.__qualname__] = simple_call_counter[fn.__qualname__] + 1
---> 21 return fn(*args, **kwargs)
File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_subclasses/fake_tensor.py:1276, in FakeTensorMode.__torch_dispatch__(self, func, types, args, kwargs)
1272 assert (
1273 torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.FAKE) is None
1274 ), func
1275 try:
-> 1276 return self.dispatch(func, types, args, kwargs)
1277 except TypeError:
1278 log.exception("fake tensor raised TypeError")
File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_subclasses/fake_tensor.py:1816, in FakeTensorMode.dispatch(self, func, types, args, kwargs)
1813 return func(*args, **kwargs)
1815 if self.cache_enabled:
-> 1816 return self._cached_dispatch_impl(func, types, args, kwargs)
1817 else:
1818 return self._dispatch_impl(func, types, args, kwargs)
File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_subclasses/fake_tensor.py:1377, in FakeTensorMode._cached_dispatch_impl(self, func, types, args, kwargs)
1375 else:
1376 self._validate_cache_key(func, args, kwargs)
-> 1377 output = self._dispatch_impl(func, types, args, kwargs)
1378 entry = self._make_cache_entry(state, key, func, args, kwargs, output)
1379 key.strip_shape_env()
File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_subclasses/fake_tensor.py:2392, in FakeTensorMode._dispatch_impl(self, func, types, args, kwargs)
2388 log.exception("failed while attempting to run meta for %s", func)
2389 raise
2391 return maybe_propagate_real_tensors(
-> 2392 self.wrap_meta_outputs_with_default_device_logic(
2393 r, func, flat_args, device=kwargs.get("device")
2394 )
2395 )
File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_subclasses/fake_tensor.py:2514, in FakeTensorMode.wrap_meta_outputs_with_default_device_logic(self, r, func, flat_args, device)
2511 else:
2512 return e
-> 2514 return tree_map(wrap, r)
File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/utils/_pytree.py:991, in tree_map(func, tree, is_leaf, *rests)
989 leaves, treespec = tree_flatten(tree, is_leaf=is_leaf)
990 flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests]
--> 991 return treespec.unflatten(map(func, *flat_args))
File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/utils/_pytree.py:830, in TreeSpec.unflatten(self, leaves)
828 def unflatten(self, leaves: Iterable[Any]) -> PyTree:
829 if not isinstance(leaves, (list, tuple)):
--> 830 leaves = list(leaves)
831 if len(leaves) != self.num_leaves:
832 raise ValueError(
833 f"treespec.unflatten(leaves): `leaves` has length {len(leaves)} "
834 f"but the spec refers to a pytree that holds {self.num_leaves} "
835 f"items ({self}).",
836 )
File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_subclasses/fake_tensor.py:2492, in FakeTensorMode.wrap_meta_outputs_with_default_device_logic.<locals>.wrap(e)
2486 return e
2488 if common_device is None:
2489 (
2490 common_device,
2491 has_scalar_only_inputs,
-> 2492 ) = FakeTensor._find_common_device(func, flat_args)
2494 is_our_fake = self.is_our_fake(e)
2495 if is_our_fake:
File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_subclasses/fake_tensor.py:890, in FakeTensor._find_common_device(func, flat_args)
885 raise RuntimeError(
886 f"Unhandled FakeTensor Device Propagation for {func}, found two different devices {common_device}, {t.device}"
887 )
889 for arg in flat_args:
--> 890 merge_devices(arg)
892 # some functions that allow Python numbers to bind to Tensors
893 # if we have failed to find a device, and we're running one of these operators,
894 # we must have scalar only inputs
895 if should_allow_numbers_as_tensors(func) and common_device is None:
896 # ops with scalar only inputs always have result on cpu
File ~/.cache/pypoetry/virtualenvs/nhwDQ0y-py3.12/lib/python3.12/site-packages/torch/_subclasses/fake_tensor.py:885, in FakeTensor._find_common_device.<locals>.merge_devices(t)
881 return
883 # mismatching devices of non-zero dim tensors, throw
884 # This might be valid behavior and need to be explicitly modeled, e.g. reshape_as
--> 885 raise RuntimeError(
886 f"Unhandled FakeTensor Device Propagation for {func}, found two different devices {common_device}, {t.device}"
887 )
TorchRuntimeError: Failed running call_function <built-in function scaled_dot_product_attention>(*(FakeTensor(..., device='cuda:0', size=(2, 16, 3, 64),
grad_fn=<PermuteBackward0>), FakeTensor(..., device='cuda:0', size=(2, 16, 3, 64),
grad_fn=<PermuteBackward0>), FakeTensor(..., device='cuda:0', size=(2, 16, 3, 64),
grad_fn=<PermuteBackward0>)), **{'attn_mask': FakeTensor(..., size=(2, 1, 3, 3)), 'dropout_p': 0.0, 'is_causal': False}):
Unhandled FakeTensor Device Propagation for aten._scaled_dot_product_efficient_attention.default, found two different devices cuda:0, cpu