Skip to content

Commit 5c156b5

Browse files
committed
Adding modelPath and vocabPath to pheye for local inference.
1 parent b9e13df commit 5c156b5

4 files changed

Lines changed: 148 additions & 1 deletion

File tree

phileas/filters/ph_eye_filter.py

Lines changed: 70 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,20 @@
2626
class PhEyeFilter(BaseFilter):
2727
def __init__(self, filter_config):
2828
super().__init__(FilterType.PH_EYE, filter_config)
29+
self._model = None
2930

3031
def filter(self, text: str, context: str = "default") -> List[Span]:
32+
model_path = getattr(self.filter_config, "model_path", "")
33+
vocab_path = getattr(self.filter_config, "vocab_path", "")
34+
labels = getattr(self.filter_config, "labels", ["PERSON"])
35+
36+
if model_path and vocab_path:
37+
return self._local_filter(text, context, model_path, vocab_path, labels)
38+
3139
endpoint = getattr(self.filter_config, "endpoint", "")
3240
if not endpoint:
3341
return []
3442

35-
labels = getattr(self.filter_config, "labels", ["PERSON"])
3643
thresholds = getattr(self.filter_config, "thresholds", {})
3744
bearer_token = getattr(self.filter_config, "bearer_token", "")
3845
timeout = getattr(self.filter_config, "timeout", 30) or 30
@@ -42,6 +49,8 @@ def filter(self, text: str, context: str = "default") -> List[Span]:
4249
"context": context,
4350
"piece": 0,
4451
"labels": list(labels),
52+
"modelPath": model_path,
53+
"vocabPath": vocab_path,
4554
}).encode("utf-8")
4655

4756
req = urllib.request.Request(
@@ -107,3 +116,63 @@ def filter(self, text: str, context: str = "default") -> List[Span]:
107116
))
108117

109118
return spans
119+
120+
def _local_filter(self, text: str, context: str, model_path: str, vocab_path: str, labels: List[str]) -> List[Span]:
121+
if not self._model:
122+
onnx = False
123+
if model_path.endswith(".onnx"):
124+
onnx = True
125+
126+
try:
127+
from gliner import GLiNER
128+
except ImportError:
129+
raise ImportError("The 'gliner' package is required for local inference. Install it with 'pip install gliner'.")
130+
131+
self._model = GLiNER.from_pretrained(model_path, onnx=onnx, vocab_path=vocab_path)
132+
133+
ph_eye_spans = self._model.predict_entities(text, labels)
134+
135+
strategies = self._get_strategies()
136+
strategy = strategies[0] if strategies else None
137+
ignored_terms = set(self._get_ignored())
138+
thresholds = getattr(self.filter_config, "thresholds", {})
139+
140+
spans: List[Span] = []
141+
for item in ph_eye_spans:
142+
label = item.get("label", "")
143+
score = float(item.get("score", 0.0))
144+
span_text = item.get("text", "")
145+
start = int(item.get("start", 0))
146+
end = int(item.get("end", 0))
147+
148+
if labels and label not in labels:
149+
continue
150+
151+
threshold = thresholds.get(label.upper(), 0.0)
152+
if score < threshold:
153+
continue
154+
155+
if span_text in ignored_terms:
156+
continue
157+
158+
if label.upper() == "PERSON":
159+
filter_type = "person"
160+
else:
161+
filter_type = label.lower() if label else FilterType.PH_EYE
162+
163+
replacement = (
164+
strategy.get_replacement(filter_type, span_text) if strategy else span_text
165+
)
166+
167+
spans.append(Span(
168+
character_start=start,
169+
character_end=end,
170+
filter_type=filter_type,
171+
context=context,
172+
confidence=score,
173+
text=span_text,
174+
replacement=replacement,
175+
ignored=False,
176+
))
177+
178+
return spans

phileas/policy/identifiers.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,8 @@ class PhEyeFilterConfig:
166166
timeout: int = 30
167167
labels: List[str] = field(default_factory=lambda: ["PERSON"])
168168
remove_punctuation: bool = False
169+
model_path: str = ""
170+
vocab_path: str = ""
169171
thresholds: dict = field(default_factory=dict)
170172
ph_eye_filter_strategies: List[FilterStrategy] = field(default_factory=_default_strategies)
171173
ignored: List[str] = field(default_factory=list)
@@ -369,6 +371,8 @@ def from_dict(cls, data: dict) -> "Identifiers":
369371
timeout=d.get("timeout", 30),
370372
labels=d.get("labels", ["PERSON"]),
371373
remove_punctuation=d.get("removePunctuation", False),
374+
model_path=d.get("modelPath", ""),
375+
vocab_path=d.get("vocabPath", ""),
372376
thresholds=d.get("thresholds", {}),
373377
ph_eye_filter_strategies=_strategies_from_dict(d, "phEyeFilterStrategies"),
374378
ignored=d.get("ignored", []),
@@ -524,6 +528,8 @@ def to_dict(self) -> dict:
524528
"timeout": cfg.timeout,
525529
"labels": cfg.labels,
526530
"removePunctuation": cfg.remove_punctuation,
531+
"modelPath": cfg.model_path,
532+
"vocabPath": cfg.vocab_path,
527533
"thresholds": cfg.thresholds,
528534
"phEyeFilterStrategies": [s.to_dict() for s in cfg.ph_eye_filter_strategies],
529535
"ignored": cfg.ignored,

pyproject.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,10 @@ dev = [
6767
server = [
6868
"flask>=3.0",
6969
]
70+
gliner = [
71+
"gliner>=0.1.0",
72+
"onnxruntime>=1.16.0",
73+
]
7074

7175
[project.scripts]
7276
phileas-server = "phileas.server:main"

tests/test_filters.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1064,6 +1064,74 @@ def test_url_error_raises_ioerror(self):
10641064
with pytest.raises(IOError):
10651065
f.filter("John Smith was here.")
10661066

1067+
def test_local_filter_gliner(self):
1068+
import unittest.mock as mock
1069+
config = PhEyeFilterConfig(
1070+
model_path="/path/to/model.bin",
1071+
vocab_path="/path/to/vocab.txt",
1072+
labels=["PERSON"]
1073+
)
1074+
f = PhEyeFilter(config)
1075+
1076+
mock_gliner_class = mock.MagicMock()
1077+
mock_model = mock.MagicMock()
1078+
mock_gliner_class.from_pretrained.return_value = mock_model
1079+
1080+
# Mocking the response from GLiNER
1081+
mock_model.predict_entities.return_value = [
1082+
{"start": 0, "end": 10, "label": "PERSON", "text": "John Smith", "score": 0.95}
1083+
]
1084+
1085+
with mock.patch.dict("sys.modules", {"gliner": mock.MagicMock()}):
1086+
import gliner
1087+
with mock.patch("gliner.GLiNER", mock_gliner_class):
1088+
spans = f.filter("John Smith was here.")
1089+
1090+
assert len(spans) == 1
1091+
assert spans[0].text == "John Smith"
1092+
assert spans[0].filter_type == "person"
1093+
assert spans[0].confidence == 0.95
1094+
1095+
mock_gliner_class.from_pretrained.assert_called_once_with("/path/to/model.bin", onnx=False, vocab_path="/path/to/vocab.txt")
1096+
1097+
def test_local_filter_onnx(self):
1098+
import unittest.mock as mock
1099+
config = PhEyeFilterConfig(
1100+
model_path="/path/to/model.onnx",
1101+
vocab_path="/path/to/vocab.txt",
1102+
labels=["PERSON"]
1103+
)
1104+
f = PhEyeFilter(config)
1105+
1106+
mock_gliner_class = mock.MagicMock()
1107+
mock_model = mock.MagicMock()
1108+
mock_gliner_class.from_pretrained.return_value = mock_model
1109+
1110+
mock_model.predict_entities.return_value = [
1111+
{"start": 0, "end": 10, "label": "PERSON", "text": "John Smith", "score": 0.95}
1112+
]
1113+
1114+
with mock.patch.dict("sys.modules", {"gliner": mock.MagicMock()}):
1115+
import gliner
1116+
with mock.patch("gliner.GLiNER", mock_gliner_class):
1117+
spans = f.filter("John Smith was here.")
1118+
1119+
assert len(spans) == 1
1120+
mock_gliner_class.from_pretrained.assert_called_once_with("/path/to/model.onnx", onnx=True, vocab_path="/path/to/vocab.txt")
1121+
1122+
def test_local_filter_import_error(self):
1123+
import unittest.mock as mock
1124+
config = PhEyeFilterConfig(
1125+
model_path="/path/to/model.bin",
1126+
vocab_path="/path/to/vocab.txt"
1127+
)
1128+
f = PhEyeFilter(config)
1129+
1130+
with mock.patch.dict("sys.modules", {"gliner": None}):
1131+
with pytest.raises(ImportError) as excinfo:
1132+
f.filter("Some text.")
1133+
assert "The 'gliner' package is required" in str(excinfo.value)
1134+
10671135
def test_policy_json_ph_eye(self):
10681136
from phileas.policy.policy import Policy
10691137
policy_json = json.dumps({

0 commit comments

Comments
 (0)