Skip to content

Commit 758bdeb

Browse files
authored
Download model files to the current working directory by default and fix the Chinese file name issue (#103)
- Introduced logging in `examples/download_model.py`
- Added logging in `examples/upload_repo.py`
- Updated CLI command help text
- Replaced print statements with logger calls in `file_upload.py`
- Enhanced logging in `repository.py`
- Improved logging in `snapshot_download.py`
- Added cache path constants in `local_folder.py`
- Cleaned up the cache after upload in `main.py`
- Added error logging for dataset, space, and model info retrieval in `utils.py`
- Migrated setup configuration to `pyproject.toml`
1 parent 646251e commit 758bdeb

File tree

11 files changed

+146
-46
lines changed

11 files changed

+146
-46
lines changed

examples/download_model.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import logging
12
from pycsghub.snapshot_download import snapshot_download
23
# token = "your access token"
34
token = None
@@ -9,6 +10,14 @@
910
allow_patterns = ["*.json"]
1011
ignore_patterns = ["tokenizer.json"]
1112

13+
# set log level
14+
logging.basicConfig(
15+
level=getattr(logging, "INFO"),
16+
format='%(asctime)s - %(levelname)s - %(message)s',
17+
datefmt='%Y-%m-%d %H:%M:%S',
18+
handlers=[logging.StreamHandler()]
19+
)
20+
1221
result = snapshot_download(
1322
repo_id,
1423
repo_type=repo_type,

examples/upload_repo.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,21 @@
1+
import logging
12
from pycsghub.repository import Repository
23

34
# token = "your access token"
45
token = None
56

7+
# set log level
8+
logging.basicConfig(
9+
level=getattr(logging, "INFO"),
10+
format='%(asctime)s - %(levelname)s - %(message)s',
11+
datefmt='%Y-%m-%d %H:%M:%S',
12+
handlers=[logging.StreamHandler()]
13+
)
14+
615
r = Repository(
7-
repo_id="wanghh2003/ds15",
16+
repo_id="wanghh2000/ds16",
817
upload_path="/Users/hhwang/temp/bbb/jsonl",
9-
user_name="wanghh2003",
18+
user_name="wanghh2000",
1019
token=token,
1120
repo_type="dataset",
1221
)

pycsghub/cli.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def version_callback(value: bool):
4949
"source": typer.Option("--source", help="Specify the source of the repository (e.g. 'csg', 'hf', 'ms')."),
5050
}
5151

52-
@app.command(name="download", help="Download model/dataset from OpenCSG Hub", no_args_is_help=True)
52+
@app.command(name="download", help="Download model/dataset/space from OpenCSG Hub", no_args_is_help=True)
5353
def download(
5454
repo_id: Annotated[str, OPTIONS["repoID"]],
5555
repo_type: Annotated[RepoType, OPTIONS["repoType"]] = RepoType.MODEL,
@@ -245,7 +245,6 @@ def stop_finetune(
245245
}
246246
)
247247
def main(
248-
version: bool = OPTIONS["version"],
249248
log_level: str = OPTIONS["log_level"]
250249
):
251250
# for example: format='%(asctime)s - %(name)s:%(funcName)s:%(lineno)d - %(levelname)s - %(message)s',

pycsghub/file_upload.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33
from typing import Optional
44
from pycsghub.constants import (DEFAULT_REVISION)
55
from pycsghub.utils import (build_csg_headers, get_endpoint)
6+
import logging
7+
8+
logger = logging.getLogger(__name__)
69

710
def http_upload_file(
811
repo_id: str,
@@ -25,7 +28,7 @@ def http_upload_file(
2528
form_data = {'file_path': destination_path, 'branch': revision, 'message': 'upload' + file_path}
2629
response = requests.post(http_url, headers=post_headers, data=form_data, files=file_data)
2730
if response.status_code == 200:
28-
print(f"file '{file_path}' upload successfully.")
31+
logger.info(f"file '{file_path}' upload successfully.")
2932
else:
30-
print(f"fail to upload {file_path} with response code: {response.status_code}, error: {response.content.decode()}")
33+
logger.error(f"fail to upload {file_path} with response code: {response.status_code}, error: {response.content.decode()}")
3134

pycsghub/repository.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@
1717
from pycsghub.utils import (build_csg_headers,
1818
model_id_to_group_owner_name,
1919
get_endpoint)
20+
import logging
21+
22+
logger = logging.getLogger(__name__)
2023

2124
def ignore_folders(folder, contents):
2225
ignored = []
@@ -185,7 +188,7 @@ def create_new_branch(self):
185188
})
186189
response = requests.post(url, json=data, headers=headers)
187190
if response.status_code != 200:
188-
print(f"create branch on {url} response: {response.text}")
191+
logger.info(f"create branch on {url} response: {response.text}")
189192
response.raise_for_status()
190193
return response
191194

@@ -208,7 +211,7 @@ def create_new_repo(self):
208211
})
209212
response = requests.post(url, json=data, headers=headers)
210213
if response.status_code != 200:
211-
print(f"create repo on {url} response: {response.text}")
214+
logger.info(f"create repo on {url} response: {response.text}")
212215
response.raise_for_status()
213216
return response
214217

@@ -291,6 +294,7 @@ def track_large_files(self, work_dir: str, pattern: str = ".") -> List[str]:
291294
if filename in deleted_files:
292295
continue
293296

297+
logger.debug(f"Checking file {filename} for LFS tracking.")
294298
path_to_file = os.path.join(os.getcwd(), work_dir, filename)
295299
size_in_mb = os.path.getsize(path_to_file) / (1024 * 1024)
296300

@@ -302,11 +306,23 @@ def track_large_files(self, work_dir: str, pattern: str = ".") -> List[str]:
302306

303307
return files_to_be_tracked_with_lfs
304308

309+
def set_git_config_quotepath(self, work_dir: str) -> None:
310+
try:
311+
self.run_subprocess("git config --global core.quotepath false".split(), work_dir)
312+
except subprocess.CalledProcessError:
313+
try:
314+
self.run_subprocess("git config core.quotepath false".split(), work_dir)
315+
except subprocess.CalledProcessError:
316+
pass
317+
305318
def list_files_to_be_staged(self, work_dir: str, pattern: str = ".") -> List[str]:
306319
try:
320+
self.set_git_config_quotepath(work_dir)
307321
p = self.run_subprocess("git ls-files --exclude-standard -mo".split() + [pattern], work_dir)
308-
if len(p.stdout.strip()):
309-
files = p.stdout.strip().split("\n")
322+
output = p.stdout.strip()
323+
if len(output):
324+
logger.debug(f"Files to be staged: {output}")
325+
files = output.split("\n")
310326
else:
311327
files = []
312328
except subprocess.CalledProcessError as exc:

pycsghub/snapshot_download.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@
1414
from pycsghub.constants import DEFAULT_REVISION, REPO_TYPES
1515
from pycsghub import utils
1616
from pycsghub.constants import REPO_TYPE_MODEL
17+
import logging
18+
19+
logger = logging.getLogger(__name__)
1720

1821
def snapshot_download(
1922
repo_id: str,
@@ -35,14 +38,18 @@ def snapshot_download(
3538
repo_type = REPO_TYPE_MODEL
3639
if repo_type not in REPO_TYPES:
3740
raise ValueError(f"Invalid repo type: {repo_type}. Accepted repo types are: {str(REPO_TYPES)}")
41+
3842
if cache_dir is None:
3943
cache_dir = get_cache_dir(repo_type=repo_type)
4044
if isinstance(cache_dir, Path):
4145
cache_dir = str(cache_dir)
46+
4247
temporary_cache_dir = os.path.join(cache_dir, 'temp')
4348
os.makedirs(temporary_cache_dir, exist_ok=True)
4449

45-
if local_dir is not None and isinstance(local_dir, Path):
50+
if local_dir is None:
51+
local_dir = os.getcwd()
52+
if isinstance(local_dir, Path):
4653
local_dir = str(local_dir)
4754

4855
group_or_owner, name = model_id_to_group_owner_name(repo_id)
@@ -83,7 +90,7 @@ def snapshot_download(
8390
repo_file_info = pack_repo_file_info(repo_file, revision)
8491
if cache.exists(repo_file_info):
8592
file_name = os.path.basename(repo_file_info['Path'])
86-
print(f"File {file_name} already in '{cache.get_root_location()}', skip downloading!")
93+
logger.info(f"File {file_name} already in '{cache.get_root_location()}', skip downloading!")
8794
continue
8895

8996
# get download url
@@ -95,6 +102,7 @@ def snapshot_download(
95102
endpoint=download_endpoint,
96103
source=source)
97104
# todo support parallel download api
105+
logger.debug(f"Downloading {repo_file} from {url}")
98106
http_get(
99107
url=url,
100108
local_dir=temp_cache_dir,
@@ -106,7 +114,7 @@ def snapshot_download(
106114
# todo using hash to check file integrity
107115
temp_file = os.path.join(temp_cache_dir, repo_file)
108116
savedFile = cache.put_file(repo_file_info, temp_file)
109-
print(f"Saved file to '{savedFile}'")
117+
logger.info(f"Saved file to '{savedFile}'")
110118

111119
cache.save_model_version(revision_info={'Revision': revision})
112120
return os.path.join(cache.get_root_location())

pycsghub/upload_large_folder/local_folder.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@
2222
key_lfs_uploaded_ids = "lfs_uploaded_ids"
2323
key_remote_oid = "remote_oid"
2424

25+
cache_path = ".cache"
26+
cache_csghub = "csghub"
27+
2528
@dataclass(frozen=True)
2629
class LocalUploadFilePaths:
2730
path_in_repo: str
@@ -184,7 +187,7 @@ def get_local_upload_paths(local_dir: Path, filename: str) -> LocalUploadFilePat
184187

185188

186189
def csghub_dir(local_dir: Path) -> Path:
187-
path = local_dir / ".cache" / "csghub"
190+
path = local_dir / cache_path / cache_csghub
188191
path.mkdir(exist_ok=True, parents=True)
189192

190193
gitignore = path / ".gitignore"

pycsghub/upload_large_folder/main.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
from pycsghub.constants import DEFAULT_REVISION
1818
import os
1919
import signal
20+
import shutil
21+
from .local_folder import cache_path, cache_csghub
2022

2123
logger = logging.getLogger(__name__)
2224

@@ -67,7 +69,7 @@ def upload_large_folder_internal(
6769

6870
items = [
6971
(paths, read_upload_metadata(folder_path, paths.path_in_repo))
70-
for paths in tqdm(paths_list, desc=f"recovering from cache metadata from {folder_path}/.cache")
72+
for paths in tqdm(paths_list, desc=f"recovering from cache metadata from {folder_path}/f{cache_path}")
7173
]
7274

7375
logger.info(f"starting {num_workers} worker threads for upload tasks")
@@ -109,6 +111,13 @@ def upload_large_folder_internal(
109111

110112
print(status.current_report())
111113
logging.info("large folder upload process is complete!")
114+
115+
clean_path = os.path.join(folder_path, cache_path, cache_csghub)
116+
if os.path.exists(clean_path):
117+
try:
118+
shutil.rmtree(clean_path)
119+
except Exception as e:
120+
logging.error(f"failed to remove cache path: {e}")
112121
except KeyboardInterrupt:
113122
print("Terminated by Ctrl+C")
114123
os.kill(os.getpid(), signal.SIGTERM)

pycsghub/utils.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@
1414
from pycsghub._token import _get_token_from_file, _get_token_from_environment
1515
from urllib.parse import quote, urlparse
1616
from pycsghub.constants import S3_INTERNAL
17+
import logging
1718

19+
logger = logging.getLogger(__name__)
1820

1921
def get_session() -> requests.Session:
2022
session = requests.Session()
@@ -221,6 +223,8 @@ def dataset_info(
221223
if files_metadata:
222224
params["blobs"] = True
223225
r = requests.get(path, headers=headers, timeout=timeout, params=params)
226+
if r.status_code != 200:
227+
logger.error(f"get dataset meta info from {path} response: {r.text}")
224228
r.raise_for_status()
225229
data = r.json()
226230
return DatasetInfo(**data)
@@ -282,6 +286,8 @@ def space_info(
282286
if files_metadata:
283287
params["blobs"] = True
284288
r = requests.get(path, headers=headers, timeout=timeout, params=params)
289+
if r.status_code != 200:
290+
logger.error(f"get space meta info from {path} response: {r.text}")
285291
r.raise_for_status()
286292
data = r.json()
287293
return SpaceInfo(**data)
@@ -349,6 +355,8 @@ def model_info(
349355
if files_metadata:
350356
params["blobs"] = True
351357
r = requests.get(path, headers=headers, timeout=timeout, params=params)
358+
if r.status_code != 200:
359+
logger.error(f"get model meta info from {path} response: {r.text}")
352360
r.raise_for_status()
353361
data = r.json()
354362
return ModelInfo(**data)

pyproject.toml

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
[build-system]
2+
requires = ["setuptools>=61.0", "wheel"]
3+
build-backend = "setuptools.build_meta"
4+
5+
[project]
6+
name = "csghub-sdk"
7+
version = "0.7.5"
8+
description = "CSGHub SDK for downloading and uploading models, datasets, and spaces"
9+
readme = "README.md"
10+
license = { text = "Apache-2.0" }
11+
authors = [
12+
{ name = "opencsg", email = "contact@opencsg.com" }
13+
]
14+
keywords = ["ai", "machine-learning", "models", "datasets", "huggingface"]
15+
classifiers = [
16+
"Development Status :: 4 - Beta",
17+
"Intended Audience :: Developers",
18+
"License :: OSI Approved :: Apache Software License",
19+
"Operating System :: OS Independent",
20+
"Programming Language :: Python :: 3",
21+
"Programming Language :: Python :: 3.8",
22+
"Programming Language :: Python :: 3.9",
23+
"Programming Language :: Python :: 3.10",
24+
"Programming Language :: Python :: 3.11",
25+
"Programming Language :: Python :: 3.12",
26+
"Topic :: Scientific/Engineering :: Artificial Intelligence",
27+
"Topic :: Software Development :: Libraries :: Python Modules",
28+
]
29+
requires-python = ">=3.8,<3.14"
30+
dependencies = [
31+
"typer",
32+
"typing_extensions",
33+
"huggingface_hub>=0.22.2",
34+
]
35+
36+
[project.optional-dependencies]
37+
train = [
38+
"torch",
39+
"transformers>=4.33.3",
40+
"datasets>=2.20.0"
41+
]
42+
43+
[project.scripts]
44+
csghub-cli = "pycsghub.cli:app"
45+
46+
[project.urls]
47+
Homepage = "https://github.com/opencsg/csghub-sdk"
48+
Documentation = "https://github.com/opencsg/csghub-sdk"
49+
Repository = "https://github.com/opencsg/csghub-sdk"
50+
Issues = "https://github.com/opencsg/csghub-sdk/issues"
51+
52+
[tool.setuptools.packages.find]
53+
include = ["pycsghub*"]
54+
55+
[tool.setuptools.package-data]
56+
pycsghub = ["*"]
57+
58+
[tool.setuptools.dynamic]
59+
version = { attr = "pycsghub.__version__" }

0 commit comments

Comments
 (0)