[Contribute] Migrate LeRobot VLA Models to XLeRobot
📋 Issue Description
Migrate existing VLA models from LeRobot (originally built for single SO-101 arms) so they work on the XLeRobot dual-arm mobile platform.
🎯 What We Need
Target Model Options: DP3, ACT, Pi0, SmolVLA, GR00T
Requirements:
- Model inference on XLeRobot hardware/simulation
- Action space mapping from single arm to dual-arm + base
- Working demo with pick-and-place task
- Basic documentation
🔧 Expected Approach
- Study LeRobot VLA implementations (see the [LeRobot repository](https://github.com/huggingface/lerobot))
- Create an action space adapter for the XLeRobot format (see the sketch after this list)
- Test with pre-trained checkpoints
- Create demo script showing task execution
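As a rough illustration of the action space mapping, an adapter might look like the sketch below. All names and dimensions here are hypothetical and not the actual XLeRobot interface; a real adapter would read the platform's control layout from its config.

import numpy as np

# Hypothetical dimensions: 6-DoF arm + gripper per arm, 3-DoF mobile base.
SINGLE_ARM_DIM = 7
XLEROBOT_DIM = 2 * SINGLE_ARM_DIM + 3

def adapt_action(single_arm_action: np.ndarray, active_arm: str = "right") -> np.ndarray:
    """Map a single SO-101 arm action into a dual-arm + base action vector.

    The inactive arm and the base are held at zero here; a real adapter would
    instead pull their targets from the current robot state.
    """
    action = np.zeros(XLEROBOT_DIM, dtype=np.float32)
    offset = 0 if active_arm == "left" else SINGLE_ARM_DIM
    action[offset : offset + SINGLE_ARM_DIM] = single_arm_action
    return action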
📝 Acceptance Criteria
- [ ] Successfully run at least 1 VLA model on XLeRobot
- [ ] Demo video
- [ ] Basic setup documentation
- [ ] Example usage script
💡 Resources
- [LeRobot Models](https://github.com/huggingface/lerobot)
- XLeRobot Examples
- Dual-arm Control Reference
🤝 Propose Your Approach
Comment with your implementation plan! Include:
- Which models you'll start with
- How you'll handle action space mapping
- Timeline estimate
This will bring powerful VLA capabilities to our dual-arm mobile platform! 🤖
After speaking with Dor on this, here are a few pointers I am putting down below:
- There should be a Sort class that provides the direction of the sort. Sorting will also be driven from the client side (via a query param like sort=asc or sort=desc).
- For the json_file handler the $gt filter on the cursor will not work, so pagination for json_file needs to be done in the adapter itself while keeping find as it is (I'm still figuring out this part from a design perspective).
- Write DB and API test cases first, which I have started, as it will help me with point 2.
New update:
In the API tests, test_that_sessions_can_be_listed remains as-is and is responsible for testing the normal legacy non-paginated response.
Test additions:
# File: /parlant/tests/api/test_sessions.py
async def test_that_sessions_are_paginated_when_listed(
async_client: httpx.AsyncClient,
container: Container
) -> None:
agents = [
await create_agent(container, "first-agent"),
]
# 10 sessions
sessions = []
for i in range(10):
session = await create_session(
container,
agent_id=agents[0].id,
title=f"session-{i}"
)
sessions.append(session)
response = await async_client.get("/sessions", params={"limit": 5})
data = response.raise_for_status().json()
assert "sessions" in data
assert "next_cursor" in data
assert "total_count" in data
assert "has_more" in data
assert len(data["sessions"]) == 5
assert data["total_count"] == 10
assert data["has_more"] is True
async def test_sessions_pagination_with_cursor(
async_client: httpx.AsyncClient,
container: Container,
) -> None:
"""Test cursor-based pagination navigation."""
agent = await create_agent(container, "test-agent")
# 7 sessions
for i in range(7):
await create_session(container, agent_id=agent.id, title=f"session-{i}")
response1 = await async_client.get("/sessions", params={"limit": 3})
data1 = response1.raise_for_status().json()
assert len(data1["sessions"]) == 3
assert data1["has_more"] is True
assert data1["next_cursor"] is not None
response2 = await async_client.get("/sessions", params={
"cursor": data1["next_cursor"],
"limit": 3
})
data2 = response2.raise_for_status().json()
assert len(data2["sessions"]) == 3
assert data2["has_more"] is True
response3 = await async_client.get("/sessions", params={
"cursor": data2["next_cursor"],
"limit": 3
})
data3 = response3.raise_for_status().json()
assert len(data3["sessions"]) == 1
assert data3["has_more"] is False
assert data3["next_cursor"] is None
    # Check that the pages don't overlap
page1_ids = {s["id"] for s in data1["sessions"]}
page2_ids = {s["id"] for s in data2["sessions"]}
page3_ids = {s["id"] for s in data3["sessions"]}
assert page1_ids.isdisjoint(page2_ids)
assert page1_ids.isdisjoint(page3_ids)
assert page2_ids.isdisjoint(page3_ids)
async def test_sessions_pagination_sort_directions(
async_client: httpx.AsyncClient,
container: Container,
) -> None:
"""Test ascending vs descending sort."""
agent = await create_agent(container, "test-agent")
sessions = []
for i in range(7):
session = await create_session(container, agent_id=agent.id, title=f"session-{i}")
sessions.append(session)
await asyncio.sleep(0.015) # Small delay so entries have different creation_utc
response_desc = await async_client.get("/sessions", params={"limit": 7, "sort": "desc"})
data_desc = response_desc.raise_for_status().json()
response_asc = await async_client.get("/sessions", params={"limit": 7, "sort": "asc"})
data_asc = response_asc.raise_for_status().json()
assert len(data_desc["sessions"]) == len(data_asc["sessions"])
assert data_desc["sessions"][0]["id"] == data_asc["sessions"][-1]["id"]
assert data_desc["sessions"][-1]["id"] == data_asc["sessions"][0]["id"]
async def test_sessions_pagination_with_filters(
async_client: httpx.AsyncClient,
container: Container,
) -> None:
"""Test pagination combined with existing filters."""
agents = [
await create_agent(container, "agent-1"),
await create_agent(container, "agent-2"),
]
# Create sessions with different agents
for i in range(3):
await create_session(container, agent_id=agents[0].id, title=f"agent1-session-{i}")
for i in range(2):
await create_session(container, agent_id=agents[1].id, title=f"agent2-session-{i}")
response = await async_client.get("/sessions", params={
"agent_id": agents[0].id,
"limit": 2
})
data = response.raise_for_status().json()
assert len(data["sessions"]) == 2
assert data["total_count"] == 3
assert data["has_more"] is True
assert all(s["agent_id"] == agents[0].id for s in data["sessions"])
async def test_sessions_pagination_empty_results(
async_client: httpx.AsyncClient,
container: Container,
) -> None:
"""Test pagination with no sessions."""
response = await async_client.get("/sessions", params={"limit": 10})
data = response.raise_for_status().json()
assert data["sessions"] == []
assert data["total_count"] == 0
assert data["has_more"] is False
assert data["next_cursor"] is None
async def test_sessions_pagination_invalid_cursor(
async_client: httpx.AsyncClient,
container: Container,
) -> None:
"""Test pagination with invalid cursor."""
agent = await create_agent(container, "test-agent")
await create_session(container, agent_id=agent.id)
response = await async_client.get("/sessions", params={
"cursor": "invalid-cursor",
"limit": 10
})
data = response.raise_for_status().json()
assert len(data["sessions"]) == 1
assert data["total_count"] == 1
Tests for adapters:
async def test_sessions_retrieval(context: _TestContext, new_file: Path) -> None:
async with JSONFileDocumentDatabase(context.container[Logger], new_file) as session_db:
async with SessionDocumentStore(session_db) as session_store:
sessions = []
for i in range(10):
customer_id = CustomerId(f"test_customer_{i}")
title = f"test_title_{i}"
utc_now = datetime.now(timezone.utc)
session = await session_store.create_session(
creation_utc=utc_now,
customer_id=customer_id,
agent_id=context.agent_id,
title=title
)
sessions.append(session)
loaded_sessions = await session_store.list_sessions(agent_id=context.agent_id)
loaded_sessions = list(loaded_sessions)
assert len(loaded_sessions) == len(sessions)
loaded_session = loaded_sessions[0]
assert loaded_session.title == sessions[0].title
assert loaded_session.customer_id == sessions[0].customer_id
assert loaded_session.agent_id == context.agent_id
I'm working on the Mongo one.
@mc-dorzo should I start working on the implementation? Are these tests enough for now? (Please check the adapter one, as I'm not sure if it's right.)
Hey team,
First of all, apologies that there isn't much progress on this issue yet. We are still in the design phase, weighing a few options, and would appreciate input.
We have decided to go with cursor-based pagination instead of skip and limit. We have also decided to add sorting via a sort query param:
CursorQuery: TypeAlias = Annotated[
Optional[SessionId],
Query(
description="Cursor for pagination - session ID to start after",
examples=["sess_123yz"],
),
]
LimitQuery: TypeAlias = Annotated[
Optional[int],
Query(
description="Maximum number of sessions to return",
examples=[10, 50],
ge=1,
le=100,
),
]
SortQuery: TypeAlias = Annotated[
Optional[str],
Query(
description="Sort direction for the sessions to be sorted",
examples="asc",
),
]
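Purely for illustration (this is not the actual Parlant route wiring), these aliases would be consumed by the list-sessions endpoint roughly like this:

from fastapi import APIRouter

router = APIRouter()


@router.get("/sessions")
async def list_sessions(
    cursor: CursorQuery = None,
    limit: LimitQuery = None,
    sort: SortQuery = None,
) -> dict[str, object]:
    # Body elided: it would build PaginationParams from these query params
    # and call SessionStore.list_sessions_paginated().
    return {}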
Next, we have decided to add a Sort class like this:
from typing import NamedTuple, Literal
class Sort(NamedTuple):
field: str
direction: Literal["asc", "desc"] = "asc"
@classmethod
def by_creation_utc(cls, direction: Literal["asc", "desc"] = "asc") -> 'Sort':
return cls("creation_utc", direction)
@property
def is_ascending(self) -> bool:
return self.direction == "asc"
@property
def is_descending(self) -> bool:
return self.direction == "desc"
@property
def mongo_direction(self) -> int:
return 1 if self.is_ascending else -1
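For reference, a quick usage example of the helpers above:

sort = Sort.by_creation_utc("desc")
assert sort.field == "creation_utc"
assert sort.is_descending
assert sort.mongo_direction == -1  # plugs straight into a Mongo $sort stage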
After a lot of thought about which field is best suited to serve as a cursor, I had decided to go with session.id. However, it comes with a challenge: in the json_file document DB the session ID is generated randomly, so using session.id as the cursor there is problematic. For other databases the ID field is generally a primary, incremental index, so a $gt filter would work well for them. To also handle json_file, I have decided to encode a combination of session.id and session.creation_utc into the cursor.
This generated string will then be decoded and we can use the right field as the cursor; for now I have decided to use creation_utc, which has its own challenges that are discussed later.
This would live in common.py inside the persistence folder:
import base64
import json
from typing import Dict


def encode_cursor(session_id: str, creation_utc: str) -> str:
    """Encode the session ID and creation timestamp into an opaque cursor."""
    cursor_data = json.dumps(
        {"id": session_id, "creation_utc": creation_utc},
        sort_keys=True,
    )
    return base64.urlsafe_b64encode(cursor_data.encode()).decode().rstrip("=")


def decode_cursor(cursor: str) -> Dict[str, str]:
    """Decode a cursor back into the session ID and creation timestamp."""
    cursor += "=" * (-len(cursor) % 4)  # restore the stripped base64 padding
    decoded = base64.urlsafe_b64decode(cursor).decode()
    return json.loads(decoded)
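A round-trip example with made-up values:

cursor = encode_cursor("sess_123yz", "2024-03-24T10:00:00Z")
assert decode_cursor(cursor) == {
    "id": "sess_123yz",
    "creation_utc": "2024-03-24T10:00:00Z",
}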
In addition, it will also have these classes:
from dataclasses import dataclass
from typing import Generic, Optional, Sequence, TypeVar

T = TypeVar("T")


@dataclass(frozen=True)
class PaginationParams:
    cursor: Optional[str] = None
    limit: int = 50
    sort: Optional[Sort] = None


@dataclass(frozen=True)
class PaginationResult(Generic[T]):
    items: Sequence[T]
    next_cursor: Optional[str]
    total_count: int
    has_more: bool
I want to keep this backwards compatible: people already using the existing Sessions API will continue to use the older behaviour, and we will provide a migration guide for moving to the new paginated API. So I will add these two abstract methods in document_database.py:
from parlant.core.persistence.common import PaginationParams, PaginationResult
class DocumentCollection(ABC, Generic[TDocument]):
# ... existing methods unchanged ...
@abstractmethod
async def find_paginated(
self,
filters: Where,
pagination: PaginationParams,
) -> PaginationResult[TDocument]:
"""Finds documents with cursor-based pagination."""
...
@abstractmethod
async def count(
self,
filters: Where,
) -> int:
"""Counts documents matching filters."""
...
Then in the SessionStore and SessionDocumentStore we make changes like this:
from parlant.core.persistence.common import PaginationParams, PaginationResult
class SessionStore(ABC):
# ... existing methods unchanged ...
@abstractmethod
async def list_sessions_paginated(
self,
agent_id: Optional[AgentId] = None,
customer_id: Optional[CustomerId] = None,
pagination: PaginationParams = PaginationParams(),
) -> PaginationResult[Session]: ...
class SessionDocumentStore(SessionStore):
# ... existing methods unchanged ...
@override
async def list_sessions_paginated(
self,
agent_id: Optional[AgentId] = None,
customer_id: Optional[CustomerId] = None,
pagination: PaginationParams = PaginationParams(),
) -> PaginationResult[Session]:
async with self._lock.reader_lock:
filters = {
**({"agent_id": {"$eq": agent_id}} if agent_id else {}),
**({"customer_id": {"$eq": customer_id}} if customer_id else {}),
}
result = await self._session_collection.find_paginated(
filters=cast(Where, filters),
pagination=pagination,
)
return PaginationResult(
items=[self._deserialize_session(doc) for doc in result.items],
next_cursor=result.next_cursor,
total_count=result.total_count,
has_more=result.has_more,
)
And then on the adapters:
mongo_db.py
from parlant.core.persistence.common import (
PaginationParams,
PaginationResult,
decode_cursor,
encode_cursor
)
class MongoDocumentCollection(DocumentCollection[TDocument]):
# ... existing methods unchanged ...
async def find_paginated(
self,
filters: Where,
pagination: PaginationParams,
) -> PaginationResult[TDocument]:
        pipeline = []
        if filters:
            pipeline.append({"$match": filters})
        sort = pagination.sort
        if pagination.cursor and sort:
            cursor_values = decode_cursor(pagination.cursor)
            # Only documents strictly after the cursor position
            op = "$gt" if sort.is_ascending else "$lt"
            pipeline.append({"$match": {sort.field: {op: cursor_values[sort.field]}}})
        if sort:
            pipeline.append({"$sort": {sort.field: sort.mongo_direction}})
        pipeline.append({
            "$facet": {
                # Fetch one extra item to detect whether another page exists
                "items": [{"$limit": pagination.limit + 1}],
                # Note: this count runs after the cursor $match, so it reflects
                # the remaining documents rather than the full filtered set
                "total_count": [{"$count": "count"}],
            }
        })
        result = await self._collection.aggregate(pipeline).to_list(length=None)
        items = result[0]["items"]
        total_count = result[0]["total_count"][0]["count"] if result[0]["total_count"] else 0
        has_more = len(items) > pagination.limit
        if has_more:
            items = items[: pagination.limit]
        next_cursor = None
        if has_more and items and sort:
            last_item = items[-1]
            next_cursor = encode_cursor(last_item["id"], last_item["creation_utc"])
return PaginationResult(
items=items,
next_cursor=next_cursor,
total_count=total_count,
has_more=has_more,
)
async def count(self, filters: Where) -> int:
return await self._collection.count_documents(filters)
json_file
from parlant.core.persistence.common import (
    PaginationParams,
    PaginationResult,
    Sort,
    decode_cursor,
    encode_cursor,
)
class JSONFileDocumentCollection(DocumentCollection[TDocument]):
# ... existing methods unchanged ...
    @override
    async def find_paginated(
        self,
        filters: Where,
        pagination: PaginationParams,
    ) -> PaginationResult[TDocument]:
        async with self._lock.reader_lock:
            filtered_docs = [doc for doc in self.documents if matches_filters(filters, doc)]
            total_count = len(filtered_docs)
            sort = pagination.sort
            if sort:
                filtered_docs.sort(
                    key=lambda doc: doc.get(sort.field, ""),
                    reverse=sort.is_descending,
                )
            start_idx = 0
            if pagination.cursor and sort:
                cursor_values = decode_cursor(pagination.cursor)
                # Default to an empty page if nothing lies after the cursor
                start_idx = len(filtered_docs)
                for i, doc in enumerate(filtered_docs):
                    if self._is_after_cursor(doc, cursor_values, sort):
                        start_idx = i
                        break
            end_idx = start_idx + pagination.limit
            items = filtered_docs[start_idx:end_idx]
            has_more = end_idx < len(filtered_docs)
            next_cursor = None
            if has_more and items and sort:
                last_item = items[-1]
                next_cursor = encode_cursor(last_item["id"], last_item["creation_utc"])
            return PaginationResult(
                items=items,
                next_cursor=next_cursor,
                total_count=total_count,
                has_more=has_more,
            )

    def _is_after_cursor(
        self,
        doc: TDocument,
        cursor_values: dict[str, str],
        sort: Sort,
    ) -> bool:
        doc_value = doc.get(sort.field, "")
        cursor_value = cursor_values.get(sort.field, "")
        if doc_value == cursor_value:
            return False
        return doc_value > cursor_value if sort.is_ascending else doc_value < cursor_value
@override
async def count(self, filters: Where) -> int:
async with self._lock.reader_lock:
return len([doc for doc in self.documents if matches_filters(filters, doc)])
Challenge
Using creation_utc has its own pitfalls.
At scale, with over 10 million records, sessions can end up with similar or identical creation_utc timestamps. As more and more new sessions get created, overlaps and collisions become increasingly likely; my estimate comes to around 19 collisions per 10 million records per month.
Another performance concern with creation_utc is that it is a string, so $gt performs a lexicographic comparison, which goes character by character.
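Since the cursor already encodes both id and creation_utc, the usual mitigation for those collisions is keyset pagination with a tie-break on id. For the ascending case, the cursor filter would look roughly like this (a sketch only, not wired into the adapter yet):

def compound_cursor_filter(cursor_values: dict[str, str]) -> dict:
    """Match documents strictly after (creation_utc, id), so duplicate timestamps don't repeat or skip rows."""
    return {
        "$or": [
            {"creation_utc": {"$gt": cursor_values["creation_utc"]}},
            {
                "creation_utc": cursor_values["creation_utc"],
                "id": {"$gt": cursor_values["id"]},
            },
        ]
    }

The corresponding $sort stage would then need to order by creation_utc and id together so the page boundary stays deterministic.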
Thank you @mc-dorzo for your continued help, I really appreciate it.
@kichanyurd what do you think?
I did some research on the challenges of using creation_utc as the cursor, and here are my findings:
Most production deployments of Parlant will use (or should use) MongoDB for persistence. Having an index on the creation_utc field will dramatically improve performance.
Case: 10 million records in the sessions collection. Without an index:
// MongoDB has to scan EVERY document
db.sessions.find({"creation_utc": {"$lt": "2024-03-24T10:00:00Z"}})
// Scans: 10,000,000 documents
// Time: ~30+ seconds
With an index:
db.sessions.createIndex({"creation_utc": -1})  // Descending index
// The same query now uses the index
db.sessions.find({"creation_utc": {"$lt": "2024-03-24T10:00:00Z"}})
// Index lookup: O(log n) ≈ 23 operations for 10M records
// Time: ~5-50ms
Lexicographic comparison IS fast because:
- Index is pre-sorted - MongoDB doesn't compare strings during query
- Binary search - O(log n) to find position
- Sequential read - Once position found, just read next N records
An important thing to note is that an index comes with its own storage overhead. For 10M records with an index on creation_utc, the overhead would be roughly:
Each entry: ~32 bytes (timestamp string + document pointer). Total index: 10M × 32 bytes ≈ 320 MB.
This is an acceptable overhead for the performance gain.
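If we go this route, the index could be created once when the collection is set up; assuming the Mongo adapter keeps using motor, something along these lines:

from motor.motor_asyncio import AsyncIOMotorCollection


async def ensure_session_indexes(sessions: AsyncIOMotorCollection) -> None:
    # Descending index matching the default sort direction.
    await sessions.create_index([("creation_utc", -1)])
    # If we adopt the (creation_utc, id) tie-break, a compound index covers it.
    await sessions.create_index([("creation_utc", -1), ("id", -1)])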
Coming to the json_file document DB, which I feel is a completely different issue to solve, since:
class JSONFileDocumentCollection(DocumentCollection[TDocument]):
def __init__(
self,
database: JSONFileDocumentDatabase,
name: str,
schema: type[TDocument],
data: Sequence[TDocument] | None = None,
) -> None:
self._database = database
self._name = name
self._schema = schema
self._op_counter = 0
self._lock = ReaderWriterLock()
self.documents = list(data) if data else []
As the data gets loaded entirely into memory, this becomes increasingly problematic with large datasets; the JSON file DB is fundamentally not designed for them. To address this I would suggest moving to SQLite, as the advantages are significant.
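For what it's worth, keyset pagination maps naturally onto SQL, which is part of why SQLite feels like a better fit for larger local datasets. An illustrative schema and query only, not an actual Parlant adapter:

import sqlite3

con = sqlite3.connect(":memory:")
con.execute(
    "CREATE TABLE sessions (id TEXT PRIMARY KEY, creation_utc TEXT NOT NULL, data TEXT)"
)
con.execute("CREATE INDEX idx_sessions_creation ON sessions (creation_utc, id)")

# Fetch the page strictly after the cursor (creation_utc, id); with the index,
# only the next `limit` rows are read instead of the whole table.
rows = con.execute(
    "SELECT id, creation_utc, data FROM sessions"
    " WHERE (creation_utc, id) > (?, ?)"
    " ORDER BY creation_utc, id"
    " LIMIT ?",
    ("2024-03-24T10:00:00Z", "sess_123yz", 50),
).fetchall()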
Let me know what you all think.
First of all, @Agam1997, great job on the deep research; you're one of a kind!
Two notes:
1. Could you please update the interface for the case where we actually want to extend the find() function in DocumentCollection? We need it to remain backward compatible, but that's possible as long as any newly introduced parameters are optional.
2. I noticed you wrote creation_utc in a format that excludes microseconds. Do you think that could make a difference?
class DocumentCollection(ABC, Generic[TDocument]):
@abstractmethod
async def find(
self,
filters: Where,
sort: Optional[Sort] = None,
limit: Optional[int] = None,
cursor: Optional[str] = None,
) -> Union[Sequence[TDocument], PaginationResult[TDocument]]:
"""Finds documents. Returns PaginationResult if pagination params provided."""
...
I suppose it would look like this:
class JSONFileDocumentCollection(DocumentCollection[TDocument]):
@override
async def find(
self,
filters: Where,
sort: Optional[Sort] = None,
limit: Optional[int] = None,
cursor: Optional[str] = None,
) -> Union[Sequence[TDocument], PaginationResult[TDocument]]:
is_paginated = sort is not None or limit is not None or cursor is not None
async with self._lock.reader_lock:
filtered_docs = [doc for doc in self.documents if matches_filters(filters, doc)]
if not is_paginated:
# Legacy behavior
return filtered_docs
total_count = len(filtered_docs)
            if sort:
                filtered_docs.sort(
                    key=lambda doc: doc.get(sort.field, ""),
                    reverse=sort.is_descending,
                )
            if cursor and sort:
                cursor_data = decode_cursor(cursor)
                cursor_value = cursor_data.get(sort.field, "")
                # Find the first document strictly after the cursor position;
                # default to the end so a stale cursor yields an empty page
                start_idx = len(filtered_docs)
                for i, doc in enumerate(filtered_docs):
                    doc_value = doc.get(sort.field, "")
                    if doc_value == cursor_value:
                        continue
                    is_after = (
                        doc_value < cursor_value
                        if sort.is_descending
                        else doc_value > cursor_value
                    )
                    if is_after:
                        start_idx = i
                        break
                # Apply the cursor offset
                filtered_docs = filtered_docs[start_idx:]
            items = filtered_docs[:limit] if limit else filtered_docs
            has_more = limit is not None and len(filtered_docs) > limit
            next_cursor = None
            if has_more and items and sort:
                last_item = items[-1]
                next_cursor = encode_cursor(last_item["id"], last_item["creation_utc"])
return PaginationResult(
items=items,
next_cursor=next_cursor,
total_count=total_count,
has_more=has_more,
)
Is the return type Union[Sequence[TDocument], PaginationResult[TDocument]] correct? It just feels very non-intuitive to me, so I might've made a number of mistakes here.