
Commit 964218c
code formatted using Ruff - Adithya S K
1 parent 0601ed7

29 files changed: +1542, -938 lines

download.py (+10, -10)
@@ -1,21 +1,21 @@
 """
 Script to download models
 """
+
 import argparse
 from omniparse import load_omnimodel
 
+
 def download_models():
-
     parser = argparse.ArgumentParser(description="Download models for omniparse")
-
-    parser.add_argument("--documents", action='store_true', help="Load document models")
-    parser.add_argument("--media", action='store_true', help="Load media models")
-    parser.add_argument("--web", action='store_true', help="Load web models")
+
+    parser.add_argument("--documents", action="store_true", help="Load document models")
+    parser.add_argument("--media", action="store_true", help="Load media models")
+    parser.add_argument("--web", action="store_true", help="Load web models")
     args = parser.parse_args()
-
-
+
     load_omnimodel(args.documents, args.media, args.web)
-
 
-if __name__ == '__main__':
-    download_models()
+
+if __name__ == "__main__":
+    download_models()

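Note on the download.py change: the reformatting is behaviour-preserving; the three store_true flags still map one-to-one onto load_omnimodel's positional booleans. A minimal usage sketch (not part of the commit; the invocation shown is illustrative only):

    # Running `python download.py --documents` parses the flags and is
    # equivalent to calling the loader directly:
    from omniparse import load_omnimodel

    load_omnimodel(True, False, False)  # documents only; media and web models stay unloaded
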
omniparse/__init__.py (+15, -7)
@@ -13,12 +13,12 @@
 URL: https://github.com/VikParuchuri/marker/blob/master/LICENSE
 
 Description:
-This section of the code was adapted from the marker repository to load all the OCR, layout and reading order detection models.
+This section of the code was adapted from the marker repository to load all the OCR, layout and reading order detection models.
 All credits for the original implementation go to VikParuchuri.
 """
 
 import torch
-from typing import Any
+from typing import Any
 from pydantic import BaseModel
 from transformers import AutoProcessor, AutoModelForCausalLM
 import whisper
@@ -35,8 +35,10 @@ class SharedState(BaseModel):
     whisper_model: Any = None
     crawler: Any = None
 
+
 shared_state = SharedState()
 
+
 def load_omnimodel(load_documents: bool, load_media: bool, load_web: bool):
     global shared_state
     print_omniparse_text_art()
@@ -46,22 +48,28 @@ def load_omnimodel(load_documents: bool, load_media: bool, load_web: bool):
         shared_state.model_list = load_all_models()
         print("[LOG] ✅ Loading Vision Model")
         # if device == "cuda":
-        shared_state.vision_model = AutoModelForCausalLM.from_pretrained('microsoft/Florence-2-base', trust_remote_code=True).to(device)
-        shared_state.vision_processor = AutoProcessor.from_pretrained('microsoft/Florence-2-base', trust_remote_code=True)
-
+        shared_state.vision_model = AutoModelForCausalLM.from_pretrained(
+            "microsoft/Florence-2-base", trust_remote_code=True
+        ).to(device)
+        shared_state.vision_processor = AutoProcessor.from_pretrained(
+            "microsoft/Florence-2-base", trust_remote_code=True
+        )
+
     if load_media:
         print("[LOG] ✅ Loading Audio Model")
         shared_state.whisper_model = whisper.load_model("small")
-
+
     if load_web:
         print("[LOG] ✅ Loading Web Crawler")
         shared_state.crawler = WebCrawler(verbose=True)
 
+
 def get_shared_state():
     return shared_state
 
+
 def get_active_models():
     print(shared_state)
     # active_models = [key for key, value in shared_state.dict().items() if value is not None]
     # print(f"These are the active model : {active_models}")
-    return shared_state
+    return shared_state

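Note on the omniparse/__init__.py change: the Florence-2 calls are only wrapped for line length, so the module's public surface (load_omnimodel, get_shared_state, the SharedState fields) is unchanged. A minimal sketch of typical use, assuming the package and its model dependencies are installed:

    # Sketch: load only the document stack, then inspect the shared state.
    from omniparse import load_omnimodel, get_shared_state

    load_omnimodel(load_documents=True, load_media=False, load_web=False)
    state = get_shared_state()

    # vision_model/vision_processor are populated only when load_documents=True;
    # whisper_model and crawler remain None in this configuration.
    print(state.vision_model is not None, state.whisper_model is None)
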
omniparse/chunking/__init__.py (+32, -18)
@@ -5,21 +5,22 @@
 from nltk.tokenize import sent_tokenize
 from omniparse.web.model_loader import load_nltk_punkt
 
+
 # Define the abstract base class for chunking strategies
 class ChunkingStrategy(ABC):
-
     @abstractmethod
     def chunk(self, text: str) -> list:
         """
         Abstract method to chunk the given text.
         """
         pass
-
+
+
 # Regex-based chunking
 class RegexChunking(ChunkingStrategy):
     def __init__(self, patterns=None, **kwargs):
         if patterns is None:
-            patterns = [r'\n\n']  # Default split pattern
+            patterns = [r"\n\n"]  # Default split pattern
         self.patterns = patterns
 
     def chunk(self, text: str) -> list:
@@ -30,24 +31,26 @@ def chunk(self, text: str) -> list:
                 new_paragraphs.extend(re.split(pattern, paragraph))
             paragraphs = new_paragraphs
         return paragraphs
-
-
-# NLP-based sentence chunking
+
+
+# NLP-based sentence chunking
 class NlpSentenceChunking(ChunkingStrategy):
     def __init__(self, **kwargs):
         load_nltk_punkt()
         pass
 
-    def chunk(self, text: str) -> list:
+    def chunk(self, text: str) -> list:
         sentences = sent_tokenize(text)
-        sens = [sent.strip() for sent in sentences]
-
+        sens = [sent.strip() for sent in sentences]
+
         return list(set(sens))
-
+
+
 # Topic-based segmentation using TextTiling
 class TopicSegmentationChunking(ChunkingStrategy):
-
     def __init__(self, num_keywords=3, **kwargs):
         import nltk as nl
+
         self.tokenizer = nl.toknize.TextTilingTokenizer()
         self.num_keywords = num_keywords
 
@@ -59,8 +62,14 @@ def chunk(self, text: str) -> list:
     def extract_keywords(self, text: str) -> list:
         # Tokenize and remove stopwords and punctuation
         import nltk as nl
+
         tokens = nl.toknize.word_tokenize(text)
-        tokens = [token.lower() for token in tokens if token not in nl.corpus.stopwords.words('english') and token not in string.punctuation]
+        tokens = [
+            token.lower()
+            for token in tokens
+            if token not in nl.corpus.stopwords.words("english")
+            and token not in string.punctuation
+        ]
 
         # Calculate frequency distribution
         freq_dist = Counter(tokens)
@@ -71,18 +80,25 @@ def chunk_with_topics(self, text: str) -> list:
         # Segment the text into topics
         segments = self.chunk(text)
         # Extract keywords for each topic segment
-        segments_with_topics = [(segment, self.extract_keywords(segment)) for segment in segments]
+        segments_with_topics = [
+            (segment, self.extract_keywords(segment)) for segment in segments
+        ]
         return segments_with_topics
-
+
+
 # Fixed-length word chunks
 class FixedLengthWordChunking(ChunkingStrategy):
     def __init__(self, chunk_size=100, **kwargs):
         self.chunk_size = chunk_size
 
     def chunk(self, text: str) -> list:
         words = text.split()
-        return [' '.join(words[i:i + self.chunk_size]) for i in range(0, len(words), self.chunk_size)]
-
+        return [
+            " ".join(words[i : i + self.chunk_size])
+            for i in range(0, len(words), self.chunk_size)
+        ]
+
+
 # Sliding window chunking
 class SlidingWindowChunking(ChunkingStrategy):
     def __init__(self, window_size=100, step=50, **kwargs):
@@ -93,7 +109,5 @@ def chunk(self, text: str) -> list:
         words = text.split()
         chunks = []
         for i in range(0, len(words), self.step):
-            chunks.append(' '.join(words[i:i + self.window_size]))
+            chunks.append(" ".join(words[i : i + self.window_size]))
         return chunks
-
-

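Note on the chunking change: the strategies are only reformatted (wrapped comprehensions, double quotes), not rewritten, so the classes behave as before. A small usage sketch against the classes in this file; the sample text and parameter values are made up for illustration:

    # Sketch: exercising three of the reformatted chunking strategies.
    from omniparse.chunking import (
        RegexChunking,
        FixedLengthWordChunking,
        SlidingWindowChunking,
    )

    text = "First paragraph of the document.\n\nSecond paragraph with a few more words."

    print(RegexChunking().chunk(text))                               # split on the r"\n\n" default
    print(FixedLengthWordChunking(chunk_size=4).chunk(text))         # fixed 4-word chunks
    print(SlidingWindowChunking(window_size=4, step=2).chunk(text))  # overlapping 4-word windows
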