Update README.md
Browse files
README.md
CHANGED
@@ -61,13 +61,13 @@ MODEL_NAME = "p1atdev/wd-swinv2-tagger-v3-hf"
|
|
61 |
model = AutoModelForImageClassification.from_pretrained(
|
62 |
MODEL_NAME,
|
63 |
)
|
64 |
-
model.eval()
|
65 |
processor = AutoImageProcessor.from_pretrained(MODEL_NAME, trust_remote_code=True)
|
66 |
|
67 |
image = Image.open("sample.webp")
|
68 |
inputs = processor.preprocess(image, return_tensors="pt")
|
69 |
|
70 |
-
|
|
|
71 |
logits = torch.sigmoid(outputs.logits[0])
|
72 |
|
73 |
# get probabilities
|
|
|
61 |
model = AutoModelForImageClassification.from_pretrained(
|
62 |
MODEL_NAME,
|
63 |
)
|
|
|
64 |
processor = AutoImageProcessor.from_pretrained(MODEL_NAME, trust_remote_code=True)
|
65 |
|
66 |
image = Image.open("sample.webp")
|
67 |
inputs = processor.preprocess(image, return_tensors="pt")
|
68 |
|
69 |
+
with torch.no_grad():
|
70 |
+
outputs = model(**inputs.to(model.device, model.dtype))
|
71 |
logits = torch.sigmoid(outputs.logits[0])
|
72 |
|
73 |
# get probabilities
|